예제 #1
0
    def __init__(self, listOfLogfiles):
        '''
        Initializes the concatenator with a list of log files.

        Every log file is parsed for finish / interrupt (snapshot) / resume
        events, which are recorded by one SnapshotListener per file. To get a
        list of iterators over the merged logs call concate().

        Files that fail to parse are reported through Log.error and skipped.
        '''

        self.__listeners = []
        # Events the parser reports to each file's SnapshotListener.
        events = {
            Concatenator.finish:
            re.compile('Optimization Done.'),
            Concatenator.interrupt:
            re.compile(
                "((?:Snapshotting solver state to (?:binary proto|HDF5) file ).*)"
            ),
            Concatenator.resume:
            re.compile("((?<= Resuming from ).*)"),
        }

        for logFile in listOfLogfiles:
            log_iter = self.__createIterFromFile(logFile)
            parser = Parser(log_iter, events)
            listener = Concatenator.SnapshotListener(logFile)
            parser.addListener(listener)
            try:
                parser.parseLog()
            except Exception:
                # A broken log file is reported and skipped; the others are
                # still concatenated. (Exception, not a bare except, so that
                # KeyboardInterrupt/SystemExit are not swallowed.)
                callerId = Log.getCallerId("Error Parsing Log File:" +
                                           str(logFile))
                Log.error("", callerId)
            else:
                self.__listeners.append(listener)
예제 #2
0
    def fetchParserData(self):
        """Fetch the parser data of the remote session and replay it locally.

        Sends a FETCHPARSERDATA request over the session transaction. If the
        reply has a positive status, the registered keys, row updates, handled
        events and log lines are forwarded to the local parser / Log and the
        method returns. On an error status the host's errors are passed to
        _handleErrors first; every unsuccessful case ends with the generic
        "Failed to fetch parser data from host." error.
        """
        if self._assertConnection() and self.transaction is not None:
            self.transaction.send({
                "key": Protocol.SESSION,
                "subkey": SessionProtocol.FETCHPARSERDATA
            })
            reply = self.transaction.asyncRead(
                staging=True, attr=("subkey", SessionProtocol.FETCHPARSERDATA))
            if reply and reply["status"]:
                rows = reply["ParserRows"]
                keys = reply["ParserKeys"]
                handles = reply["ParserHandle"]
                logs = reply["ParserLogs"]

                for phase, key in keys:
                    self.parser.sendParserRegisterKeys(phase, key)
                for phase, row in rows:
                    self.parser.sendParserUpdate(phase, row)
                for event, line, groups in handles:
                    self.parser.sendParserHandle(event, line, groups)
                for message, isError in logs:
                    report = Log.error if isError else Log.log
                    report(message, self.getCallerId())
                return
            if reply:
                self._handleErrors(reply["error"])
        self._handleErrors(["Failed to fetch parser data from host."])
예제 #3
0
    def loadFiles(self, filenames):
        """
        Load every log file in the list of filenames and add it to the list of
        logs.

        Missing files are reported via Log.error and skipped. Each loaded file
        is registered under the id "external_<file name>" with the display
        name "[ext] <file name>", parsed, and remembered in the settings under
        "logFiles" so it is saved in the project file.
        """
        for f in filenames:
            if not os.path.exists(f):
                Log.error("External log not found: " + f,
                          Log.getCallerId("plotter"))
                continue

            # NOTE(review): the file handle is handed to the parser and never
            # closed here; confirm whether Parser closes it after parsing.
            parser = Parser(open(f, 'r'), OrderedDict(),
                            Log.getCallerId(str(f)))
            tail = os.path.basename(str(f))
            logId = "external_" + tail
            logName = "[ext] " + tail
            self.putLog(logId, parser, logName=logName)
            parser.parseLog()

            # This is for saving the loaded logs in the project file
            # (create the necessary entries if they don't exist)
            if self.__settings is None:
                self.__settings = {}
            self.__settings.setdefault("logFiles", {})[logId] = f
예제 #4
0
    def snapshot(self):
        """ Create a snapshot from the training state.

        Only a RUNNING session with an attached process can be snapshotted;
        the request is sent to the process as SIGHUP.

        Return
            True if the snapshot was created
            False if the snapshot could not be created
        """
        if self.getState() is not State.RUNNING:
            return False
        if not self.proc:
            Log.error(
                'Could not take a session snapshot in state ' +
                str(self.getState()), self.getCallerId())
            return False

        self.last_solverstate = None
        try:
            self.proc.send_signal(signal.SIGHUP)
        except Exception as e:
            Log.error('Failed to take snapshot: ' + str(e),
                      self.getCallerId())
            return False

        Log.log('Snapshot was saved for session ' +
                self.getRunLogFileName(True), self.getCallerId())
        return True
예제 #5
0
    def __init__(self, layers, view, vertical=False):
        """Build the node graph for *layers* and sort it for display in *view*.

        :param layers: dict-like collection of layers; its values() are turned
                       into nodes. A list is rejected with an error.
        :param view: passed on to sort().
        :param vertical: selects the vertical layout, which uses a larger node
                         offset (70 instead of 30).
        """
        self.vertical = vertical
        # NOTE: OFFSET is set on the Node class, so it affects all nodes.
        NodeSort.Node.OFFSET = 70 if vertical else 30
        self.__hasCycle = False

        self.columnWidths = []
        self.rowHeights = []

        # all abstract nodes (type Node) that exist
        self.nodes = []
        # all gui nodes that exist
        self.guiNodes = []

        if isinstance(layers, list):
            # callerId is not defined in this method (NameError before);
            # request one from Log instead.
            Log.error("wrong type of layers", Log.getCallerId("node_sort"))
            return
        # fills nodes and guiNodes (presumably also detects cycles and sets
        # self.__hasCycle -- it is checked right below)
        self.getAbstractNodes(layers.values())

        if self.__hasCycle:
            Log.error("There is a Cycle in that graph, cannot be handled yet",
                      Log.getCallerId("node_sort"))
            return
        if not layers.values():
            return
        self.sort(self.nodes, view)
예제 #6
0
 def printLog(self, log, error=False):
     """Route a log message to the Log module or through a signal.

     With a log id set, the message goes to Log.error (error=True) or
     Log.log. Without a log id it is emitted via printLogSignl together
     with the error flag.
     """
     if self.log_id is None:
         self.printLogSignl.emit(log, error)
         return
     if error:
         Log.error(log, self.log_id)
     else:
         Log.log(log, self.log_id)
예제 #7
0
 def open(self):
     '''Open the LMDB environment located at the set path.

     A previously opened environment is closed first. The directory is only
     opened (with max_dbs=2) if it contains a "data.mdb" file; otherwise an
     error is logged.
     '''
     if self._env:
         self.close()
     if self._path:
         # os.path.join handles paths with and without a trailing separator;
         # plain string concatenation also probed a bogus "<path>data.mdb".
         if os.path.exists(os.path.join(self._path, "data.mdb")):
             self._env = lmdb.open(self._path, max_dbs=2)
         else:
             Log.error("Dir is not valid LMDB: " + self._path, self.logid)
예제 #8
0
 def open(self):
     '''Open the LevelDB database located at the set path.

     A previously opened database is closed first. The directory is only
     opened if it contains a "CURRENT" file; otherwise an error is logged.
     '''
     if self._db:
         self.close()
     if self._path:
         # os.path.join handles paths with and without a trailing separator;
         # plain string concatenation also probed a bogus "<path>CURRENT".
         if os.path.exists(os.path.join(self._path, "CURRENT")):
             self._db = leveldb.LevelDB(self._path)
         else:
             Log.error("Dir is not valid LEVELDB: " + self._path,
                       self.logid)
예제 #9
0
    def pause(self):
        """ Pause the process.

        Requests a snapshot, waits one second, then kills the caffe and tee
        processes. If a snapshot file is found afterwards, the session becomes
        PAUSED and its iteration is read from the snapshot file name;
        otherwise it becomes WAITING. The session is saved in both cases.

        Return
            True if the process was paused
            False if the pause failed
        """
        if self.getState() is State.RUNNING:
            if self.proc:
                self.snapshot()
                # give the caffe process a second to save the state
                time.sleep(1)
                try:
                    if self.proc:
                        self.proc.kill()
                except Exception as e:
                    Log.error('Pausing session failed: ' + str(e),
                              self.getCallerId())
                try:
                    if self.tee:
                        self.tee.kill()
                except Exception:
                    # tee is optional (a session may run without it)
                    pass
                self.proc = None
                self.tee = None
                self.last_solverstate = None
                snap = self._getLastSnapshotFromSnapshotDirectory()
                if snap is not None:
                    self.last_solverstate = os.path.basename(snap)
                    # e.g. "..._iter_1000.solverstate[.h5]" -> 1000
                    regex_iter = re.compile(
                        r'iter_([\d]+)\.solverstate[\.\w-]*$')
                    iter_match = regex_iter.search(self.last_solverstate)
                    if iter_match:
                        self.iteration = int(iter_match.group(1))
                    self.iterationChanged.emit()
                    self.setState(State.PAUSED)
                    Log.log(
                        'Session ' + self.getRunLogFileName(True) +
                        ' was paused', self.getCallerId())
                else:
                    self.setState(State.WAITING)
                    Log.log(
                        'Session ' + self.getRunLogFileName(True) +
                        ' was halted', self.getCallerId())
                self.save()
                return True
            else:
                Log.error(
                    'Could not pause a session in state ' +
                    str(self.getState()), self.getCallerId())
                self.setState(State.UNDEFINED)
        return False
예제 #10
0
 def __ensureDirectory(self, directory):
     """ Creates a directory (including missing parents) if it does not exist.

     An empty path is ignored. Creation failures are logged, not raised.
     """
     if directory == '':
         return
     if not os.path.exists(directory):
         try:
             os.makedirs(directory)
             Log.log('Created directory: ' + directory, self.getCallerId())
         except Exception as e:
             Log.error(
                 'Failed to create directory ' + directory + ': ' + str(e),
                 self.getCallerId())
예제 #11
0
 def open(self):
     '''Open the HDF5 file at the set path read-only.

     A previously opened file is closed first. If the file cannot be opened
     as HDF5, an error is logged and the database is reset to None. A missing
     file is only reported when _pathOfHdf5Txt is not set (presumably set
     when this path came from an HDF5 list file -- TODO confirm).
     '''
     if self._db:
         self.close()
     if self._path:
         if os.path.exists(self._path):
             try:
                 self._db = h5.File(self._path, 'r')
             except Exception:
                 # not a bare except: keep KeyboardInterrupt/SystemExit
                 Log.error("File not valid HDF5: " + self._path, self._logid)
                 self._db = None
         elif not self._pathOfHdf5Txt:
             Log.error("File does not exist: " + self._path, self._logid)
예제 #12
0
    def _getCurrentDatum(self):
        """Return the value at the current cursor position as a caffe Datum.

        Returns None if there is no cursor, or if the value cannot be parsed
        (the latter is logged as an error).
        """
        from backend.caffe.path_loader import PathLoader
        caffe = PathLoader().importCaffe()
        if not self._cursor:
            return None
        raw_datum = self._cursor.value()
        datum = caffe.proto.caffe_pb2.Datum()

        try:
            datum.ParseFromString(raw_datum)
        except Exception:
            # not a bare except: keep KeyboardInterrupt/SystemExit
            Log.error("LMDB does not contain valid data: " + self._path,
                      self.logid)
            return None
        return datum
예제 #13
0
    def createRemoteSession(self, remote, state_dictionary=None):
        """Create an entirely new session on a remote host.

        Use this only to create entirely new sessions; to load existing ones
        use the loadRemoteSession command.

        :param remote: remote[0] and remote[1] are passed to sendMsgToHost
                       (presumably host and port).
        :param state_dictionary: optional session state. When given, the
                                 layer types of its network are sent to the
                                 host and it is stored on the new session.
        :return: the new session id, or None if the host has no caffe version
                 or the session could not be created.
        """
        msg = {"key": Protocol.GETCAFFEVERSIONS}
        reply = sendMsgToHost(remote[0], remote[1], msg)
        if reply:
            remoteVersions = reply["versions"]
            if len(remoteVersions) <= 0:
                msgBox = QMessageBox(
                    QMessageBox.Warning, "Error",
                    "Cannot create remote session on a host witout a caffe-version"
                )
                msgBox.addButton("Ok", QMessageBox.NoRole)
                msgBox.exec_()
                return None

        sid = self.getNextSessionId()
        msg = {
            "key": Protocol.CREATESESSION,
            "pid": self.projectId,
            "sid": sid
        }

        # state_dictionary defaults to None: only read the layers when given
        # (previously this raised a TypeError for the default value).
        layers = []
        if state_dictionary is not None:
            networkLayers = state_dictionary["network"]["layers"]
            layers = [networkLayers[layer]["parameters"]["type"]
                      for layer in networkLayers]
        msg["layers"] = layers

        ret = sendMsgToHost(remote[0], remote[1], msg)
        if not ret:
            Log.error('Failed to create remote session! No connection to Host',
                      self.getCallerId())
            return None
        if not ret["status"]:
            for e in ret["error"]:
                Log.error(e, self.getCallerId())
            return None
        uid = ret["uid"]

        session = ClientSession(self, remote, uid, sid)
        if state_dictionary is not None:
            session.state_dictionary = state_dictionary
        self.__sessions[sid] = session
        self.newSession.emit(sid)
        return sid
예제 #14
0
 def setActiveSID(self, sid):
     """Make *sid* the active session if it is one of the valid session ids.

     On success activeSessionChanged is emitted. Otherwise an error listing
     the valid ids is logged and, if no session is active yet, the last valid
     id (or None if there is none) becomes active; the signal is not emitted
     in that case.
     """
     validSIDs = self.getValidSIDs()
     if sid not in validSIDs:
         Log.error(
             "Could not set active session to " + str(sid) +
             " valid Session-IDs are: " +
             ", ".join(str(s) for s in validSIDs), self.getCallerId())
         if not self.__activeSID:
             if len(validSIDs) > 0:
                 self.__activeSID = validSIDs[-1]
             else:
                 self.__activeSID = None
             Log.log("Active session set to " + str(self.__activeSID),
                     self.getCallerId())
         return
     self.__activeSID = sid
     self.activeSessionChanged.emit(sid)
    def __updateLayerList(self):
        """ Update the layer list with available layers found in the net
        description.

        The network of the current session is loaded (via the session for a
        ClientSession, from the internal net file otherwise), and the sorted
        names of all layers whose type is in ALLOWED_LAYERTYPES are put into
        the layer combo box. The current layer is kept if it is still listed;
        otherwise the last listed layer (or None) becomes current.
        """

        def allowedLayerNames(network):
            # sorted names of all layers whose type is allowed for display
            return sorted(
                layer["parameters"]["name"]
                for layer in network["layers"].values()
                if layer["type"].name() in self.ALLOWED_LAYERTYPES)

        layerNameList = []
        if self.__currentSessionId is not None:
            session = self.__sessionDict[self.__currentSessionId]
            if isinstance(session, ClientSession):
                netInternal = session.loadInternalNetFile()
                layerNameList = allowedLayerNames(loader.loadNet(netInternal))
            else:
                try:
                    currentNetworkPath = session.getInternalNetFile()
                    # 'with' closes the file, which was previously leaked
                    with open(currentNetworkPath, 'r') as netFile:
                        currentNetwork = loader.loadNet(netFile.read())
                    layerNameList = allowedLayerNames(currentNetwork)
                except IOError:
                    callerId = Log.getCallerId('weight-plotter')
                    Log.error("Could not open the network of this session.",
                              callerId)
                    layerNameList = []

        # updates the layer Combobox with the current layers
        self.layerComboBox.replaceItems(layerNameList)
        if (self.__currentLayerName is None
                or self.__currentLayerName not in layerNameList):
            self.__currentLayerName = (layerNameList[-1] if layerNameList
                                       else None)
        self.layerComboBox.setCurrentText(self.__currentLayerName)
예제 #16
0
    def proceed(self, snapshot=None):
        """ Continue training from the (last) snapshot.

        Only sessions in state PAUSED are continued: the caffe binary of the
        project's caffe version is started with
        "train -solver <solver> -snapshot <snapshot>", and its output is piped
        through "tee -a <run log file>" (the session continues without tee if
        that fails).

        Return
            True if the process was continued
            False if the continuation failed
        """
        if self.getState() is State.PAUSED:
            self.__ensureDirectory(self.getSnapshotDirectory())
            self.__ensureDirectory(self.logs)

            if snapshot is None:
                snapshot = self.getLastSnapshot()
            self.rid += 1
            try:
                self.getParser().setLogging(True)
                self.proc = Popen([
                    caffeVersions.getVersionByName(
                        self.project.getCaffeVersion()).getBinarypath(),
                    'train', '-solver',
                    self.getSolver(), '-snapshot', snapshot
                ],
                                  stdout=PIPE,
                                  stderr=STDOUT,
                                  cwd=self.getDirectory())
                try:
                    self.tee = Popen(
                        ['tee', '-a', self.getRunLogFileName()],
                        stdin=self.proc.stdout,
                        stdout=PIPE)
                except Exception as e:
                    # continue without tee
                    Log.error('Failed to start tee: ' + str(e),
                              self.getCallerId())
                self.setState(State.RUNNING)
                Log.log(
                    'Session ' + self.getRunLogFileName(True) +
                    ' was proceeded', self.getCallerId())
                self.__startParsing()
                return True
            except Exception as e:
                Log.error('Failed to continue session: ' + str(e),
                          self.getCallerId())
                # Check whether the caffe binary exists, to give a more
                # helpful hint. Resolving the path may fail as well, which
                # must not escape from this error handler. (Previously this
                # used the nonexistent os.file.exists and an undefined
                # caffe_bin, so it crashed instead of reporting.)
                try:
                    caffe_bin = caffeVersions.getVersionByName(
                        self.project.getCaffeVersion()).getBinarypath()
                except Exception:
                    caffe_bin = None
                if caffe_bin is not None and not os.path.exists(caffe_bin):
                    Log.error(
                        'CAFFE_BINARY directory does not exists: ' +
                        str(caffe_bin) +
                        '! Please set CAFFE_BINARY to run a session.',
                        self.getCallerId())
        elif self.getState() in (State.FAILED, State.FINISHED):
            Log.error(
                'Could not continue a session in state ' +
                str(self.getState()), self.getCallerId())
        return False
예제 #17
0
    def save(self, includeProtoTxt=False):
        """Save the current session to a json file.

        Writes the session state, iteration, max iteration, project id, last
        snapshot, pretrained weights and a copy of the state dictionary
        (without the layers' "type" entries) to
        baristaSessionFile(self.directory). If includeProtoTxt is True, the
        solver prototxt and the original and internal net prototxt files are
        written as well.
        """
        toSave = {
            "SessionState": self.state,
            "Iteration": self.iteration,
            "MaxIter": self.max_iter,
            "ProjectID": self.project.projectId,
        }

        self.__ensureDirectory(self.directory)
        Log.log("Saving current Session status to disk.", self.getCallerId())
        if self.last_solverstate:
            toSave["LastSnapshot"] = self.last_solverstate
        pretrainedWeights = self.getPretrainedWeights()
        if pretrainedWeights:
            toSave["PretrainedWeights"] = pretrainedWeights
        if self.state_dictionary:
            # work on a deep copy so the live state dictionary is untouched
            serializedDict = copy.deepcopy(self.state_dictionary)
            if includeProtoTxt:
                if "solver" in self.state_dictionary:
                    solver = self.buildSolverPrototxt()
                    with open(self.getSolver(log=False), 'w') as f:
                        f.write(solver)
                else:
                    Log.error(
                        "Could not save a solver prototxt file, because no solver settings are defined.",
                        self.getCallerId())

            if "network" in serializedDict:
                if includeProtoTxt:
                    net = self.buildNetPrototxt(internalVersion=False)
                    with open(self.getOriginalNetFile(log=False), 'w') as f:
                        f.write(net)
                    net = self.buildNetPrototxt(internalVersion=True)
                    with open(self.getInternalNetFile(log=False), 'w') as f:
                        f.write(net)
                if "layers" in serializedDict["network"]:
                    # Drop the "type" entries before dumping (presumably not
                    # JSON-serializable). pop() tolerates layers without a
                    # type, where del raised a KeyError.
                    for layer in serializedDict["network"]["layers"].values():
                        layer.pop("type", None)
            else:
                Log.error(
                    "Could not save the network state because no state was defined.",
                    self.getCallerId())

            toSave["NetworkState"] = serializedDict

        with open(baristaSessionFile(self.directory), "w") as f:
            json.dump(toSave, f, sort_keys=True, indent=4)
 def __saveImage(self):
     """ Gives the user the opportunity to save the image on disk.

     The user can choose from the allowed extensions and is alerted if the
     entered extension is not valid. If no extension is entered, ".png" is
     appended. Cancelling the dialog aborts without saving. Failures are
     logged.
     """
     allowedExtensions = [".png", ".bmp", ".jpeg", ".jpg"]
     # dialog filter " *.png *.bmp *.jpeg *.jpg"; str.join instead of
     # reduce(), which is not a builtin on Python 3
     allowedAsString = "".join(" *" + ext for ext in allowedExtensions)
     callerId = self.getCallerId()
     try:
         filename = ""
         # While the entered extension does not match the allowed extensions
         while not self.__validateFilename(filename, allowedExtensions):
             fileDialog = QFileDialog()
             fileDialog.setDefaultSuffix("png")
             filenameArray = fileDialog.getSaveFileName(
                 self,
                 "Save File",
                 filename,
                 allowedAsString
             )
             filename = filenameArray[0]
             if filename != "":
                 _, extension = os.path.splitext(filename)
                 # If no extension has been entered, append .png
                 if extension == "":
                     filename += ".png"
             else:
                 # If user clicks on abort, leave the loop
                 break
             # Show an alert message, when an unknown extension has been entered
             if not self.__validateFilename(filename, allowedExtensions):
                 msg = QMessageBox()
                 msg.setIcon(QMessageBox.Warning)
                 msg.setWindowTitle("Warning")
                 msg.setText("Please enter a valid extension:\n"
                             + allowedAsString)
                 msg.exec_()
         if filename != "":
             if self.__currentNet:
                 image = calculateConvWeights(self.__currentNet, self.__currentLayerName)
                 self.canvasWidget.saveImage(image, filename)
                 Log.log("Saved image under " + filename, callerId)
             else:
                 Log.error("Try to save image of weights without a network being loaded.", callerId)
     except Exception as e:
         Log.error("Saving the file failed. " + str(e), callerId)
예제 #19
0
    def _getFirstDatum(self):
        """Return the first value of the database as a caffe Datum.

        Returns None if there is no iterator, the database is empty, or the
        first value cannot be parsed (the latter is logged as an error).
        """
        from backend.caffe.path_loader import PathLoader
        caffe = PathLoader().importCaffe()
        dbIter = self._getIter()
        if not dbIter:
            return None
        for _, raw_datum in dbIter:
            datum = caffe.proto.caffe_pb2.Datum()
            try:
                datum.ParseFromString(raw_datum)
            except Exception:
                # not a bare except: keep KeyboardInterrupt/SystemExit
                Log.error(
                    "LEVELDB does not contain valid data: " + self._path,
                    self.logid)
                return None
            # only the first entry is needed
            return datum
        return None
예제 #20
0
 def delete(self):
     """ Delete the session directory and disconnect signals.

     The session is paused first and a failed directory removal is logged.
     Each signal is disconnected on its own: disconnect() raises when nothing
     is connected (PyQt behaviour), and that previously skipped the remaining
     disconnects and the project's deleteSession notification.
     """
     self.pause()
     try:
         shutil.rmtree(self.getDirectory())
     except Exception as e:
         Log.error('Could not remove session directory: ' + str(e),
                   self.getCallerId())
     for sig in (self.stateChanged, self.iterationChanged,
                 self.snapshotAdded):
         try:
             sig.disconnect()
         except Exception:
             # no connected slots: nothing to disconnect
             pass
     try:
         self.project.deleteSession.emit(self.sid)
     except Exception:
         # best effort, as before
         pass
     Log.removeCallerId(self.caller_id, False)
예제 #21
0
    def _modifyH5TxtFile(self, dir, state=None):
        """Write per-layer copies of HDF5 list files into *dir*.

        The network is taken from *state* or from the active session's state
        dictionary. For each layer whose hdf5_data_param.source is set to a
        relative path, the list file it names (resolved two directories above
        *dir*) is copied to "<dir>/<layerId>.txt". Empty lines are dropped, and
        entries starting with '.' get two leading os.pardir components so
        they stay valid from the new location. A missing list file is logged.
        """
        net = None
        if state:
            net = state["network"]
        else:
            session = self.getActiveSession()
            if session:
                state_dict = session.state_dictionary
                if state_dict and "network" in state_dict:
                    net = state_dict["network"]

        if not net:
            return

        h = helper.DictHelper(net)
        paramKey = "hdf5_data_param.source"
        # iterate keys directly (iteritems() is Python 2 only)
        for layerId in net.get("layers", {}):
            if not h.layerParameterIsSet(layerId, paramKey):
                continue
            paramValue = h.layerParameter(layerId, paramKey)
            if paramValue is None or os.path.isabs(paramValue):
                continue

            newFilepath = os.path.join(dir, str(layerId) + ".txt")
            oldPath = os.path.join(dir, os.pardir, os.pardir, paramValue)
            if not os.path.exists(oldPath):
                Log.error(
                    'Failed to copy hdf5txt file. File does not exists: '
                    + oldPath, self.getCallerId())
                continue

            # close the source file (it was leaked before)
            with open(oldPath) as src:
                lines = [line.rstrip('\n') for line in src]
            with open(newFilepath, "w") as f:
                for line in lines:
                    # compare by value; 'is not ""' tested identity
                    if line != "":
                        if line[:1] == '.':
                            line = os.path.join(os.pardir, os.pardir, line)
                        f.write("\n" + line)
 def __getNetwork(self, sess_id=None, snap_id=None):
     """ Return the caffe network of the given (default: current) session
     and snapshot.

     Loaded networks are cached in __alreadyOpenSnapshots as
     {session id: {snapshot id: network}}. Returns None if an id is unset,
     the snapshot file is missing, or the network could not be loaded; an
     unloadable local snapshot is reported as an unsupported hdf5 snapshot.
     """
     sess_id = self.__currentSessionId if sess_id is None else sess_id
     snap_id = self.__currentSnapshotId if snap_id is None else snap_id
     if not (sess_id and snap_id):
         return None

     cached = self.__alreadyOpenSnapshots.get(sess_id, {}).get(snap_id)
     if cached:
         return cached

     # first access of this snapshot: load the net and cache it
     session = self.__sessionDict[sess_id]
     snapName = snap_id.replace('solverstate', 'caffemodel')
     if isinstance(session, ClientSession):
         net = session.loadNetParameter(snapName)
         if net is None:
             return None
     else:
         snapshotPath = str(os.path.join(session.getSnapshotDirectory(),
                                         snapName))
         if not os.path.exists(snapshotPath):
             Log.error('Snapshot file '+snapshotPath+' does not exist!', self.getCallerId())
             return None
         net = loadNetParameter(snapshotPath)

     if net is None:
         # Show a warning message
         Log.error('The hdf5 snapshot format is not supported for the weight visualization! '
                   'This can be changed by setting the snapshot_format parameter in the solver properties.', self.getCallerId())
         return None
     self.__alreadyOpenSnapshots.setdefault(sess_id, {})[snap_id] = net
     return net
예제 #23
0
 def open(self):
     """Open every HDF5 file listed in the text file at the set path.

     Each non-empty line names an HDF5 file; lines starting with '.' are
     resolved through _makepath. One Hdf5Input is created and opened per
     line. If none of them could be opened, the database is reset to None
     and an error is logged.
     """
     if self._db:
         self.close()
     if self._path is not None:
         if os.path.exists(self._path):
             # Start from a fresh list: after a failed open _db is None,
             # which made len()/append() raise on the next open.
             self._db = []
             with open(self._path) as listFile:
                 lines = [line.rstrip('\n') for line in listFile]
             hdf5Count = 0
             for line in lines:
                 # compare by value; 'is not ""' tested identity
                 if line != "":
                     if line[:1] == '.':
                         line = self._makepath(line)
                     hdf5 = Hdf5Input(pathOfHdf5Txt=True)
                     self._db.append(hdf5)
                     hdf5.setPath(line)
                     hdf5.open()
                     if hdf5.isOpen():
                         hdf5Count += 1
             if hdf5Count == 0:
                 self._db = None
                 Log.error("File contained no valid paths to HDF5 files: {}".format(self._path), self._logid)
예제 #24
0
    def _readFile(self):
        """Read the content of the caffe.proto file.

        The proto path is taken from the first available caffe version.
        Returns an empty string and logs an error if no caffe version is
        available or the file does not exist.
        """
        import backend.barista.caffe_versions as caffeVersions
        # get path to caffe.proto file (guard against an empty version list,
        # which previously raised an IndexError)
        availableVersions = caffeVersions.getAvailableVersions()
        caffeProtoFilePath = None
        if availableVersions:
            caffeProtoFilePath = availableVersions[0].getProtopath()

        if caffeProtoFilePath and os.path.isfile(caffeProtoFilePath):
            # read complete file at once
            with open(caffeProtoFilePath) as f:
                return f.read()

        callerId = Log.getCallerId("parameter_descriptions")
        Log.error(
            "ERROR: Unable to load caffe.proto file. Parameter descriptions and deprecated flags "
            "will be left empty. Did you configure the environment variable CAFFE_ROOT?",
            callerId)
        return ""
예제 #25
0
 def parseOldLogs(self):
     """ Parse all log files in the log directory.

     Runs at most once (guarded by parse_old) and while holding self.lock.
     Files whose names end in "<digits>.<run id>.log" are ordered by run id,
     merged by a Concatenator and handed to the parser. Errors are logged,
     not raised.
     """
     locked = self.lock.acquire()
     if locked is False:
         return
     try:
         if self.parse_old:
             self.parse_old = False
             logDir = self.getLogs()
             regex_filename = re.compile(r'[\d]+\.([\d]+)\.log$')
             log_files = {}
             for entry in os.listdir(logDir):
                 filename_match = regex_filename.search(entry)
                 if not filename_match:
                     continue
                 # key files by run id
                 try:
                     run_id = int(filename_match.group(1))
                 except ValueError:
                     continue
                 log_files[run_id] = entry
             log_list = [os.path.join(logDir, log_files[run_id])
                         for run_id in sorted(log_files)]
             con = Concatenator(log_list)
             for log in con.concate():
                 try:
                     self.getParser().addLogStream(log)
                 except Exception as e:
                     Log.error(
                         'Failed to parse log file ' +
                         self.getLogFileName(True) + ": " + str(e),
                         self.getCallerId())
     except Exception as e:
         Log.error('Failed to parse old log ' + str(e), self.getCallerId())
     finally:
         if locked:
             self.lock.release()
예제 #26
0
    def parsingFinished(self):
        """ Called when the parser has processed all available streams.

        If a caffe process is still attached, wait for it to exit and kill
        tee. Return code 0 calls setFinished(); any other code sets the
        state to FAILED and logs the code. Afterwards the parser is marked
        as initialized.
        """

        if self.proc is not None:
            # Wait for caffe process, kill tee and respond to return code
            assert self.state is State.RUNNING
            rcode = self.proc.wait()
            self.proc = None
            try:
                self.tee.kill()
            except Exception:
                # tee may not have been started (self.tee can be None)
                pass
            self.tee = None
            # compare by value: 'is 0' relied on CPython small-int caching
            if rcode == 0:
                self.setFinished()
            else:
                self.setState(State.FAILED)
                Log.error('Session failed with return code ' + str(rcode),
                          self.getCallerId())

        self.setParserInitialized()
예제 #27
0
    def createNewSession(self, caffemodel=None):
        """ Create a new session and start the training.

        Create a new session, add new parser to plotter and create a new
        session widget for the session. Nothing happens if the minimum
        training requirements are not met; a failed creation is logged.

        :param caffemodel: passed on to session.start().
        """
        if not checkMinimumTrainingRequirements(self.project, self.sessionGui):
            return
        sessionID = self.project.createSession(
            state_dictionary=self.sessionGui.mainWindow.networkManager.
            getStateDictionary())
        # getSessions is a method: indexing it directly raised a TypeError
        session = self.project.getSession(sessionID) if sessionID else None
        if not session:
            Log.error('Failed to create session!', self.getCallerId())
            return
        session.snapshotAdded.connect(
            lambda: self.wPlotter.updatePlotter(self.project.getSessions()))
        self.updateWeightPlotter()
        # Since this is not set at startup, the parsing has already been done, thus the session should send signals
        session.setParserInitialized()
        session.start(caffemodel=caffemodel)
        self.filterState()
예제 #28
0
 def createNewSessionFromSnapshot(self, solverstate, old_session):
     """ Create a new session as a clone of the given session, start with the
     snapshot state.

     solverstate: snapshot state the clone starts from; passed to the
         project's cloneRemoteSession / cloneSession.
     old_session: session to clone. Remote sessions are cloned through
         cloneRemoteSession, local ones through cloneSession.

     On success the new session is set to PAUSED, wired up and made active.
     On failure (including a missing old_session) an error is logged.
     """
     if not old_session:
         # Check before calling isRemote(), which would fail on None.
         Log.error('Failed to clone session!', self.getCallerId())
         return
     if old_session.isRemote():
         sessionID = self.project.cloneRemoteSession(
             solverstate, old_session)
     else:
         sessionID = self.project.cloneSession(solverstate, old_session)
     session = self.project.getSession(sessionID)
     if session:
         session.setState(State.PAUSED)
         self.sessionConnects(session)
         session.snapshotAdded.connect(lambda: self.wPlotter.updatePlotter(
             self.project.getSessions()))
         self.setActiveSID(sessionID)
     else:
         Log.error('Failed to clone session!', old_session.getCallerId())
예제 #29
0
    def cloneRemoteSession(self, oldSolverstate, oldSession):
        """Ask the remote host to clone a session, then mirror it locally.

        A CLONESESSION request is sent to the host at ``oldSession.remote``
        (@see cloneSession in server_session_manager.py). If the host reports
        success, a local ClientSession bound to the returned uid is created,
        inherits ``oldSession.state_dictionary`` when present (to keep
        solver/net etc.), is stored under a newly obtained session id and is
        announced through ``self.newSession``.

        oldSolverstate: solverstate produced by the snapshot to clone from
        oldSession: session to clone (type ClientSession)

        Returns the new session id, or None after logging the failure.
        """
        # Both inputs are mandatory; bail out early otherwise.
        if oldSolverstate is None:
            Log.error('Could not find solver', self.getCallerId())
            return None
        if oldSession is None:
            Log.error('Failed to create session!', self.getCallerId())
            return None

        newSid = self.getNextSessionId()
        request = dict(
            key=Protocol.CLONESESSION,
            pid=self.projectId,
            sid=newSid,
            old_uid=oldSession.uid,
            old_solverstate=oldSolverstate,
        )
        reply = sendMsgToHost(oldSession.remote[0], oldSession.remote[1],
                              request)

        # A falsy reply means the host could not be reached at all.
        if not reply:
            Log.error('Failed to clone remote session! No connection to Host',
                      self.getCallerId())
            return None
        # The host answered but refused: forward every reported error.
        if not reply["status"]:
            for err in reply["error"]:
                Log.error(err, self.getCallerId())
            return None

        clone = ClientSession(self, oldSession.remote, reply["uid"], newSid)
        if hasattr(oldSession, 'state_dictionary'):
            clone.state_dictionary = oldSession.state_dictionary
        self.__sessions[newSid] = clone
        self.newSession.emit(newSid)
        return newSid
예제 #30
0
 def checkFiles(self):
     """ Check for the existence of the session directories and files.

     Logs an error through Log.error for each missing item: the session
     directory (self.directory), the log directory (self.logs), the snapshot
     directory (self.snapshot_dir) and the caffe binary belonging to the
     project's configured caffe version. Returns nothing.
     """
     if not os.path.exists(self.directory):
         Log.error('Session directory does not exists: ' + self.directory,
                   self.getCallerId())
     if not os.path.exists(self.logs):
         Log.error('Log directory does not exists: ' + self.logs,
                   self.getCallerId())
     if not os.path.exists(self.snapshot_dir):
         Log.error(
             'Snapshot directory does not exists: ' + self.snapshot_dir,
             self.getCallerId())
     # `os.file` does not exist (AttributeError); the check is os.path.exists.
     # NOTE(review): assumes getVersionByName returns a filesystem path --
     # confirm; a version object or None would make os.path.exists raise
     # TypeError.
     if not os.path.exists(
             caffeVersions.getVersionByName(
                 self.project.getCaffeVersion())):
         Log.error(
             'Caffe binary does not exists: ' +
             self.project.getCaffeVersion(), self.getCallerId())