Example #1
0
def makeTrainCfg(run):
    """Build the model-training configuration for a run.

    The returned StructDict has a single field, ``blkGrpRefs``: a list of
    two ``{'run': ..., 'phase': ...}`` dicts naming the block groups to
    train on.

    * run 1 -> run 1 phase 1 and run 1 phase 2
    * run 2 -> run 1 phase 2 and run 2 phase 1
    * any other run N -> run N-1 phase 1 and run N phase 1
    """
    trainCfg = StructDict()
    runId = run.runId
    # Select the (run, phase) pairs first, then expand them into dicts.
    if runId == 1:
        refPairs = ((1, 1), (1, 2))
    elif runId == 2:
        refPairs = ((1, 2), (2, 1))
    else:
        refPairs = ((runId - 1, 1), (runId, 1))
    trainCfg.blkGrpRefs = [{'run': refRun, 'phase': refPhase}
                           for refRun, refPhase in refPairs]
    return trainCfg
    def runRun(self, runId, scanNum=-1):
        """Process one run: stream every TR to the server, then train a model.

        Creates the per-run output directories and log file, fills in the
        output-info structure, and delegates the processing itself to
        ``_runRunBody``. The log file is closed on every exit path,
        including when processing raises.

        Args:
            runId: Identifier of the run; used to name the ``run<runId>``
                output directory and passed on to the run configuration.
            scanNum: Scan number forwarded to ``createRunConfig``
                (default -1).

        Raises:
            RequestError: if the remote patterns-file request does not
                return status code 200.
            ValidationError: if replay/validate mode is active and no
                validation model or validation data entry exists for
                this run.
            RTError: if the TR with trId 0 has no data.
        """
        # Setup output directory and output file
        runDataDir = os.path.join(self.dirs.dataDir, 'run' + str(runId))
        os.makedirs(runDataDir, exist_ok=True)
        outputInfo = StructDict()
        outputInfo.runId = runId
        outputInfo.classOutputDir = os.path.join(runDataDir, 'classoutput')
        os.makedirs(outputInfo.classOutputDir, exist_ok=True)
        outputInfo.logFilename = os.path.join(runDataDir,
                                              'fileprocessing_py.txt')
        outputInfo.logFileHandle = open(outputInfo.logFilename, 'w+')
        # Everything after the open() runs under try/finally so the log file
        # handle is released even when a request, validation or server
        # command fails part-way through the run.
        try:
            if self.webpipes is not None:
                outputInfo.webpipes = self.webpipes
            if self.webUseRemoteFiles:
                outputInfo.webUseRemoteFiles = True
                remoteRunDataDir = os.path.join(self.dirs.remoteDataDir,
                                                'run' + str(runId))
                outputInfo.remoteClassOutputDir = os.path.join(
                    remoteRunDataDir, 'classoutput')
                outputInfo.remoteLogFilename = os.path.join(
                    remoteRunDataDir, 'fileprocessing_py.txt')
            self._runRunBody(runId, scanNum, outputInfo)
        finally:
            outputInfo.logFileHandle.close()

    def _runRunBody(self, runId, scanNum, outputInfo):
        """Load the run design, send its TRs to the server and train a model.

        Private helper of ``runRun``. The caller owns
        ``outputInfo.logFileHandle`` and is responsible for closing it.

        Args:
            runId: Identifier of the run being processed.
            scanNum: Scan number forwarded to ``createRunConfig``.
            outputInfo: Structure prepared by ``runRun``; passed to the
                output helpers (``outputReplyLines``,
                ``outputPredictionFile``).
        """
        # Get patterns design file for this run
        if self.webUseRemoteFiles and self.cfg.session.getPatternsFromControlRoom:
            fileRegex = getPatternsFileRegex(self.cfg.session,
                                             self.dirs.remoteDataDir,
                                             runId,
                                             addRunDir=True)
            getNewestFileCmd = wcutils.getNewestFileReqStruct(fileRegex)
            retVals = wcutils.clientWebpipeCmd(self.webpipes, getNewestFileCmd)
            if retVals.statusCode != 200:
                raise RequestError('runRun: statusCode not 200: {}'.format(
                    retVals.statusCode))
            patterns = retVals.data
            logging.info("Using Remote Patterns file: %s", retVals.filename)
            print("Using remote patterns {}".format(retVals.filename))
        else:
            patterns, filename = getLocalPatternsFile(self.cfg.session,
                                                      self.dirs.dataDir, runId)
            print("Using patterns {}".format(filename))
        run = createRunConfig(self.cfg.session, patterns, runId, scanNum)
        validateRunCfg(run)
        self.id_fields.runId = run.runId
        logging.log(DebugLevels.L4, "Run: %d, scanNum %d", runId, run.scanNum)
        if self.cfg.session.rtData:
            # Check if images already exist and warn and ask to continue
            firstFileName = self.getDicomFileName(run.scanNum, 1)
            if os.path.exists(firstFileName):
                logging.log(DebugLevels.L3, "Dicoms already exist")
                skipCheck = self.cfg.session.skipConfirmForReprocess
                if skipCheck is None or skipCheck is False:
                    resp = input(
                        'Files with this scan number already exist. Do you want to continue? Y/N [N]: '
                    )
                    if resp.upper() != 'Y':
                        # User declined; runRun closes the log file.
                        return
            else:
                logging.log(DebugLevels.L3, "Dicoms - waiting for")
        elif self.cfg.session.replayMatFileMode or self.cfg.session.validate:
            idx = getRunIndex(self.cfg.session, runId)
            if idx >= 0 and len(self.cfg.session.validationModels) > idx:
                run.validationModel = os.path.join(
                    self.dirs.dataDir, self.cfg.session.validationModels[idx])
            else:
                raise ValidationError(
                    "Insufficient config runs or validationModels specified: "
                    "runId {}, validationModel idx {}".format(runId, idx))
            if idx >= 0 and len(self.cfg.session.validationData) > idx:
                run.validationDataFile = os.path.join(
                    self.dirs.dataDir, self.cfg.session.validationData[idx])
            else:
                raise ValidationError(
                    "Insufficient config runs or validationDataFiles specified: "
                    "runId {}, validationData idx {}".format(runId, idx))

        # ** Experimental Parameters ** #
        run.seed = time.time()
        if run.runId > 1:
            run.rtfeedback = 1
        else:
            run.rtfeedback = 0

        runCfg = copy_toplevel(run)
        reply = self.sendCmdExpectSuccess(MsgEvent.StartRun, runCfg)
        outputReplyLines(reply.fields.outputlns, outputInfo)

        if self.cfg.session.replayMatFileMode and not self.cfg.session.rtData:
            # load previous patterns data for this run
            p = utils.loadMatFile(run.validationDataFile)
            run.replay_data = p.patterns.raw

        # Begin BlockGroups (phases)
        for blockGroup in run.blockGroups:
            self.id_fields.blkGrpId = blockGroup.blkGrpId
            blockGroupCfg = copy_toplevel(blockGroup)
            logging.log(DebugLevels.L4, "BlkGrp: %d", blockGroup.blkGrpId)
            reply = self.sendCmdExpectSuccess(MsgEvent.StartBlockGroup,
                                              blockGroupCfg)
            outputReplyLines(reply.fields.outputlns, outputInfo)
            for block in blockGroup.blocks:
                self.id_fields.blockId = block.blockId
                blockCfg = copy_toplevel(block)
                logging.log(DebugLevels.L4, "Blk: %d", block.blockId)
                reply = self.sendCmdExpectSuccess(MsgEvent.StartBlock,
                                                  blockCfg)
                outputReplyLines(reply.fields.outputlns, outputInfo)
                for TR in block.TRs:
                    self.id_fields.trId = TR.trId
                    # `//` binds tighter than `+`: this is
                    # TR.vol + (run.disdaqs // run.TRTime).
                    fileNum = TR.vol + run.disdaqs // run.TRTime
                    logging.log(DebugLevels.L3, "TR: %d, fileNum %d", TR.trId,
                                fileNum)
                    if self.cfg.session.rtData:
                        # Assuming the output file volumes are still 1's based
                        trVolumeData = self.getNextTRData(run, fileNum)
                        if trVolumeData is None:
                            if TR.trId == 0:
                                errStr = "First TR {} of run {} missing data, aborting...".format(
                                    TR.trId, runId)
                                raise RTError(errStr)
                            logging.warning(
                                "TR {} missing data, sending empty data".
                                format(TR.trId))
                            # Send an all-NaN volume in place of the missing
                            # data, then skip the deadline, prediction and
                            # timing-log handling for this TR.
                            TR.data = np.full((self.cfg.session.nVoxels),
                                              np.nan)
                            reply = self.sendCmdExpectSuccess(
                                MsgEvent.TRData, TR)
                            continue
                        TR.data = applyMask(trVolumeData,
                                            self.cfg.session.roiInds)
                    else:
                        # TR.vol is 1's based to match matlab, so we want vol-1 for zero based indexing
                        TR.data = run.replay_data[TR.vol - 1]
                    processingStartTime = time.time()
                    imageAcquisitionTime = 0.0
                    pulseBroadcastTime = 0.0
                    trStartTime = 0.0
                    gotTTLTime = False
                    if self.cfg.session.enforceDeadlines is True:
                        # capture TTL pulse from scanner to calculate next deadline
                        trStartTime = self.ttlPulseClient.getTimestamp()
                        # NOTE(review): imageAcquisitionTime is still 0.0
                        # here, so the second clause can only be true for a
                        # negative TRTime; the "stale time signal" case is
                        # effectively never detected. Confirm the intent
                        # before changing the deadline logic.
                        if trStartTime == 0 or imageAcquisitionTime > run.TRTime:
                            # Either no TTL Pulse time signal or stale time signal
                            #   Approximate trStart as current time minus 500ms
                            #   because scan reconstruction takes about 500ms
                            gotTTLTime = False
                            trStartTime = time.time() - 0.5
                        else:
                            gotTTLTime = True
                            imageAcquisitionTime = time.time() - trStartTime
                            pulseBroadcastTime = (
                                trStartTime -
                                self.ttlPulseClient.getServerTimestamp())
                        # Deadline is TR_Start_Time + time between TRs +
                        #  clockSkew adjustment - 1/2 Max Net Round_Trip_Time -
                        #  Min RTT because clock skew calculation can be off
                        #  by the RTT used for calculation which is Min RTT.
                        TR.deadline = (trStartTime + self.cfg.clockSkew +
                                       run.TRTime - (0.5 * self.cfg.maxRTT) -
                                       self.cfg.minRTT)
                    reply = self.sendCmdExpectSuccess(MsgEvent.TRData, TR)
                    processingEndTime = time.time()
                    missedDeadline = False
                    if reply.fields.missedDeadline is True:
                        # TODO - store reply.fields.threadId in order to get completed reply later
                        # TODO - add a message type that retrieves previous thread results
                        missedDeadline = True
                    else:
                        # classification result
                        outputPredictionFile(reply.fields.predict, outputInfo)

                    # log the TR processing time
                    serverProcessTime = processingEndTime - processingStartTime
                    elapsedTRTime = 0.0
                    if gotTTLTime is True:
                        elapsedTRTime = time.time() - trStartTime
                    logStr = "TR:{}:{}:{:03}, fileNum {}, server_process_time {:.3f}s, " \
                             "elapsedTR_time {:.3f}s, image_time {:.3f}s, " \
                             "pulse_time {:.3f}s, gotTTLPulse {}, missed_deadline {}, " \
                             "dicom_arrival {:.5f}" \
                             .format(runId, block.blockId, TR.trId, fileNum,
                                     serverProcessTime, elapsedTRTime,
                                     imageAcquisitionTime, pulseBroadcastTime,
                                     gotTTLTime, missedDeadline, processingStartTime)
                    logging.log(DebugLevels.L3, logStr)
                    outputReplyLines(reply.fields.outputlns, outputInfo)
                del self.id_fields.trId
                # End Block
                if self.webpipes is not None:
                    cmd = {'cmd': 'subjectDisplay', 'bgcolor': '#808080'}
                    wcutils.clientWebpipeCmd(self.webpipes, cmd)
                reply = self.sendCmdExpectSuccess(MsgEvent.EndBlock, blockCfg)
                outputReplyLines(reply.fields.outputlns, outputInfo)
            del self.id_fields.blockId
            reply = self.sendCmdExpectSuccess(MsgEvent.EndBlockGroup,
                                              blockGroupCfg)
            outputReplyLines(reply.fields.outputlns, outputInfo)
        del self.id_fields.blkGrpId
        # End Run
        if self.webpipes is not None:
            # send instructions to subject window display
            cmd = {
                'cmd': 'subjectDisplay',
                'text': 'Waiting for next run to start...'
            }
            wcutils.clientWebpipeCmd(self.webpipes, cmd)
        # Train the model for this Run
        # NOTE(review): this selection duplicates the logic of makeTrainCfg;
        # consider calling it instead if it is in scope for this class.
        trainCfg = StructDict()
        if run.runId == 1:
            trainCfg.blkGrpRefs = [{'run': 1, 'phase': 1},
                                   {'run': 1, 'phase': 2}]
        elif run.runId == 2:
            trainCfg.blkGrpRefs = [{'run': 1, 'phase': 2},
                                   {'run': 2, 'phase': 1}]
        else:
            trainCfg.blkGrpRefs = [{'run': run.runId - 1, 'phase': 1},
                                   {'run': run.runId, 'phase': 1}]
        outlns = [
            '*********************************************',
            "Train Model {} {}".format(trainCfg.blkGrpRefs[0],
                                       trainCfg.blkGrpRefs[1]),
        ]
        outputReplyLines(outlns, outputInfo)
        processingStartTime = time.time()
        reply = self.sendCmdExpectSuccess(MsgEvent.TrainModel, trainCfg)
        processingEndTime = time.time()
        # log the model generation time
        logStr = "Model:{} training time {:.3f}s\n".format(
            runId, processingEndTime - processingStartTime)
        logging.log(DebugLevels.L3, logStr)
        outputReplyLines(reply.fields.outputlns, outputInfo)
        reply = self.sendCmdExpectSuccess(MsgEvent.EndRun, runCfg)
        outputReplyLines(reply.fields.outputlns, outputInfo)
        if self.cfg.session.retrieveServerFiles:
            self.retrieveRunFiles(runId)
        del self.id_fields.runId