def train_server_yolo_model(self, projectName, jsonData, modelTrainVersion):
    """Prepare data and configuration for a server-side YOLO model and start training.

    Loads the train/validation data for the given train version, builds and
    saves a detect-model train config from ``jsonData["advancedSet"]``, then
    starts a ``trainModelThread`` with both.  Nothing is returned.

    Args:
        projectName: Forwarded to the train config as ``name``.
        jsonData: Mapping containing ``"advancedSet"``, which in turn must
            provide ``"isUsePreTraindModel"``, ``"epochs"`` and
            ``"batch_size"``.
        modelTrainVersion: Train-version record, read through ``.dmtvid``,
            ``.dmPrecision`` and ``["ds_dl_list"]``.
    """
    loggerUtils.info("start server yolo model train thread [end]")

    # Step 2.1: prepare the training and validation data.
    trainSet, valSet = datasetService.loadTrainData(
        modelTrainVersion.dmtvid, modelTrainVersion['ds_dl_list'])

    # Step 2.2: assemble and persist the training parameters.
    advanced = jsonData["advancedSet"]
    usePretrained = advanced["isUsePreTraindModel"]
    configKwargs = {
        "dmtvid": modelTrainVersion.dmtvid,
        "cfg": ConstantUtils.getModelCfgPath(
            usePretrained, modelTrainVersion.dmPrecision),
        "weights": ConstantUtils.getModelWeightsPath(
            usePretrained, modelTrainVersion.dmPrecision),
        "epochs": advanced['epochs'],
        "batch_size": advanced['batch_size'],
        "project": ConstantUtils.modelBasePath,
        "name": projectName,
    }
    savedConfig = detectModelTrainConfig.getDetectModelTrainConfig(**configKwargs)
    savedConfig.save()

    # Step 2.3: hand everything to the training thread and start it.
    trainModelThread(trainSet, valSet, savedConfig).start()
def getSingleImageDetectResult(self, serviceSessionId, threshold, imgData):
    """Save an uploaded image, run the session's detector on it and pack the result.

    Args:
        serviceSessionId: Key used to look up the detect service instance.
        threshold: Accepted but not used by this method.
        imgData: Upload object; its ``filename`` is read and ``save(path)``
            is called on it.

    Returns:
        ``resultPackerUtils.packCusResult`` of a dict with ``"imagePath"``
        and ``"detectResult"``, or a packed ``EC_NO_EVALUATE_SESSION`` error
        when no detect service instance is found for the session.
    """
    randomName = fileUtils.getRandomName(imgData.filename)
    sourcePath = ConstantUtils.singleImgDetectSource + randomName
    loggerUtils.info("图片保存路径:" + sourcePath)

    # The image is written before the session lookup, so it stays on disk
    # even when the lookup below fails.
    imgData.save(sourcePath)

    # Fetch the detector currently registered for this session.
    detector = yoloDetect.getDetectServiceInstance(serviceSessionId)
    if detector is None:
        return resultPackerUtils.packErrorMsg(
            resultPackerUtils.EC_NO_EVALUATE_SESSION)

    loggerUtils.info("找到相关检测模型...")
    runConfig = detectConfigUtils.getBasicDetectConfig(
        source=sourcePath, outPath=ConstantUtils.singleImgDetectOut)
    detections = detector.detect(runConfig)
    outputImagePath = (
        ConstantUtils.imageItemPrefix + "singleImgDetectOut_" + randomName)
    return resultPackerUtils.packCusResult({
        "imagePath": outputImagePath,
        "detectResult": detections,
    })
def getVideoDetectResult(self, serviceSessionId, threshold, detectVideo):
    """Save an uploaded video, run the session's detector on it and return a play URL.

    Args:
        serviceSessionId: Key used to look up the detect service instance.
        threshold: Accepted but not used by this method.
        detectVideo: Upload object; its ``filename`` is read and
            ``save(path)`` is called on it.

    Returns:
        ``resultPackerUtils.packCusResult`` of a one-entry dict mapping
        ``ConstantUtils.videoPlayUrl`` to the play prefix plus the saved
        record's ``dereid``, or a packed ``EC_NO_EVALUATE_SESSION`` error
        when no detect service instance is found for the session.
    """
    randomName = fileUtils.getRandomName(detectVideo.filename)
    sourcePath = ConstantUtils.videoDetectSource + randomName
    loggerUtils.info("视频保存路径:" + sourcePath)

    # The video is written before the session lookup, so it stays on disk
    # even when the lookup below fails.
    detectVideo.save(sourcePath)

    # Fetch the detector currently registered for this session.
    detector = yoloDetect.getDetectServiceInstance(serviceSessionId)
    if detector is None:
        return resultPackerUtils.packErrorMsg(
            resultPackerUtils.EC_NO_EVALUATE_SESSION)

    loggerUtils.info("找到相关检测模型...")
    runConfig = detectConfigUtils.getBasicDetectConfig(
        source=sourcePath, outPath=ConstantUtils.videoDetectOut)
    # The return value of detect() is not used here; only the record that
    # points at the output location is stored.
    detector.detect(runConfig)

    record = detectRecord(
        videoSavedPath=ConstantUtils.videoDetectOut + randomName)
    record.save()

    playUrl = ConstantUtils.videoPlayPrefix + str(record.dereid)
    return resultPackerUtils.packCusResult({ConstantUtils.videoPlayUrl: playUrl})
def genFramesFromLiveStream(self, sessionId):
    """Generate an endless sequence of ``--frame`` parts for a session's live stream.

    On first iteration the session's detect thread is asked to start
    broadcasting.  Each yielded chunk is one item taken from the thread's
    stream queue, wrapped in a ``--frame`` boundary with an ``image/jpeg``
    content type.

    Args:
        sessionId: Key into ``yoloDetectThreadMap``; also concatenated into
            the log message.

    Yields:
        The framed bytes for one queue item.
    """
    loggerUtils.info("开启直播...." + sessionId)
    yoloDetectThreadMap[sessionId].startBroardcast()

    partHeader = b'--frame\r\n' b'Content-Type: image/jpeg\r\n\r\n'
    partTrailer = b'\r\n'
    while True:
        # The map is consulted on every iteration on purpose: once the
        # session is removed from it, the lookup fails and the stream ends.
        framePayload = yoloDetectThreadMap[sessionId].getStreamQueue().get()
        yield partHeader + framePayload + partTrailer
def batchInitYoloModel(self):
    """Launch a YOLO detect service for every switched-on, active service record.

    Queries ``detectServiceBean`` for records with the switch on and an
    active state, and calls ``launchYoloDetectService`` for each one that
    has a model train-version id.  Records whose ``dmtvId`` is ``None`` are
    skipped.  Nothing is returned.
    """
    loggerUtils.info("init detect model start...")
    activeServices = detectServiceBean.objects(
        dtsSwitch=ConstantUtils.SERVICE_SWITCH_ON,
        state=ConstantUtils.DATA_STATUS_ACTIVE)
    for service in activeServices:
        versionId = service['dmtvId']
        # Identity comparison is the correct test for the None singleton.
        if versionId is None:
            continue
        yoloDetectServiceImpl.launchYoloDetectService(
            sessionId=service['dtsServiceKey'], dmtvid=versionId)
    loggerUtils.info("init detect model finish***")
def trainMobile_nanodet_model(self, jsonData):
    """Build the NanoDet training configuration and start the training thread.

    Args:
        jsonData: Mapping that must provide ``'projectDir'``,
            ``'ds_dl_list'``, ``'ckptModelSavePath'`` and
            ``'entireModelSavePath'``.

    The config also carries fixed values: ``local_rank`` of ``-1`` and the
    ``"data/nanodet-self.yml"`` config path.  Nothing is returned.
    """
    loggerUtils.info("start lite nanodet model train thread [end]")
    trainSettings = dict(
        local_rank=-1,
        save_dir=jsonData['projectDir'],
        ds_dl_list=jsonData['ds_dl_list'],
        cfg="data/nanodet-self.yml",
        imageBasePath=ConstantUtils.dataBasePath,
        ckptModelSavePath=jsonData["ckptModelSavePath"],
        entireModelSavePath=jsonData["entireModelSavePath"],
    )
    nanodetTrainThread(trainSettings).start()
def launchYoloDetectService(self, sessionId=None, dmtvid=None, isWatch=False):
    """Create a detect service thread for a model version and register it under a session.

    Args:
        sessionId: Session key to register the thread under.  When ``None``
            a random one is generated.
        dmtvid: Model train-version id whose checkpoint is loaded.
        isWatch: When truthy, the session's timestamp is refreshed through
            ``ConstantUtils.updateDetectSessionTime``.

    Returns:
        ``resultPackerUtils.packCusResult`` of a dict mapping
        ``ConstantUtils.serviceSessionId`` to the session id, or ``None``
        when ``dmtvid`` is ``None``, when the session already serves this
        ``dmtvid``, or when the version has no checkpoint path.
    """
    # No model version: nothing to launch.
    if dmtvid is None:
        return None

    # Skip the reload when this session is already bound to the same version.
    if sessionId in sessionDmtvMap and sessionDmtvMap[sessionId] == dmtvid:
        return None

    # Resolve the checkpoint path from the model version id.
    versionBean = modelVersionService.getDMVersionBean(dmtvid)[0]
    modelConfig = {
        "weights": versionBean['ckptModelSavePath'],
        "device": '',
    }
    if modelConfig["weights"] is None:
        return None

    detectThread = detectServiceThread(modelConfig)
    if sessionId is None:
        sessionId = randomUtils.getRandomStr()

    # NOTE(review): a thread already registered under this sessionId (with a
    # different dmtvid) is replaced here without being stopped -- confirm
    # whether callers release it first.
    yoloDetectThreadMap[sessionId] = detectThread
    sessionDmtvMap[sessionId] = dmtvid

    # Refresh the session timestamp so it can be monitored (the original
    # comment mentions Redis and a monitor thread).
    if isWatch:
        ConstantUtils.updateDetectSessionTime(sessionId)

    # str() keeps logging from raising for a non-string session id after the
    # thread has already been registered.
    loggerUtils.info("模型启动完毕...." + str(sessionId))
    loggerUtils.info("sessions of detectThreadMap:" +
                     str(yoloDetectThreadMap.keys()))
    return resultPackerUtils.packCusResult(
        {ConstantUtils.serviceSessionId: sessionId})
class mongoSource:
    """Holds the shared ``MongoEngine`` handle and binds it to an application."""

    # Executed once, when the class body runs.
    loggerUtils.info("mongodb init...")
    mdb = MongoEngine()

    @classmethod
    def initMongoDBSource(cls, app):
        """Write ``MONGODB_SETTINGS`` into ``app.config`` and initialise ``mdb`` with ``app``.

        Host, port, user and password are read from the ``"mongo"`` section
        through ``configUtils.getConfigProperties``; the port is converted
        to ``int`` and the database name is fixed to ``'admin'``.

        Args:
            app: Application object exposing a ``config`` mapping.
        """
        def mongoProperty(key):
            # Every connection setting lives in the "mongo" section.
            return configUtils.getConfigProperties("mongo", key)

        settings = {'db': 'admin'}
        settings['host'] = mongoProperty("mongo_host")
        settings['port'] = int(mongoProperty("mongo_port"))
        settings['username'] = mongoProperty("mongo_user")
        settings['password'] = mongoProperty("mongo_pwd")

        app.config['MONGODB_SETTINGS'] = settings
        cls.mdb.init_app(app)
def releaseYoloDetectThread(self, serviceSessionId):
    """Stop and unregister the detect thread of a session.

    When the session is registered, its thread is asked to stop, joined if
    it is alive, and removed from ``yoloDetectThreadMap``; the session's
    entry in ``sessionDmtvMap`` is removed as well.

    Args:
        serviceSessionId: Session key to release.

    Returns:
        Always ``True``, whether or not the session was registered.
    """
    if serviceSessionId not in yoloDetectThreadMap:
        loggerUtils.info("sessions not in detectThreadMap:" +
                         str(serviceSessionId))
        return True

    detectThread = yoloDetectThreadMap[serviceSessionId]
    # Ask the thread to stop first.
    detectThread.stopDetect()
    # Only a thread that is alive is joined.
    if detectThread.is_alive():
        detectThread.join()

    # The map entry is dropped only after the thread has stopped.
    del yoloDetectThreadMap[serviceSessionId]
    sessionDmtvMap.pop(serviceSessionId, None)

    loggerUtils.info("release model:" + str(serviceSessionId))
    return True
def loadTrainData(self,dmtvid,ds_dl_list):
    """Collect image paths, label boxes and image shapes for a training run.

    For every dataset selection in ``ds_dl_list`` the matching
    ``dataImageItem`` documents (``state=1``) are queried.  Images without
    any entry in ``recLabelList`` are skipped.  Each selected label id is
    given a zero-based class index in order of first appearance; that
    mapping is stored (as ``str``) on the active train version ``dmtvid``.

    Args:
        dmtvid: Train-version id whose ``dl_id_index_map`` is updated.
        ds_dl_list: Iterable of dataset selections, each providing
            ``'isSelectAll'``, ``'dsId'`` and ``'dlidList'``.

    Returns:
        ``(trainDataDict, valDataDict)``: two dicts with the same content
        (``imagePathList``, ``LabelsList``, ``imageShapeList``, ``nc``,
        ``names``), each holding its own freshly built arrays.

    Raises:
        ValueError: If ``ds_dl_list`` yields no items.
    """
    imagePathList = []
    labelsList = []
    imageShapeList = []
    dlIdIndexMap = {}
    dlOrderedList = []

    dsItem = None
    for dsItem in ds_dl_list:
        selectAll = dsItem['isSelectAll'] == ConstantUtils.TRUE_TAG
        explicitSelection = dsItem['isSelectAll'] == ConstantUtils.FALSE_TAG
        if selectAll:
            images = dataImageItem.objects(dsId=dsItem["dsId"], state=1)
        else:
            images = dataImageItem.objects(
                dsId=dsItem["dsId"],
                labelIdList__in=dsItem["dlidList"],
                state=1)

        for imageItem in images:
            recLabels = imageItem['recLabelList']
            # Only images that carry annotation data take part in training.
            if not recLabels:
                continue

            itemLabels = []
            for rec in recLabels:
                dlid = rec['dlid']
                selected = selectAll or (
                    explicitSelection and dlid in dsItem["dlidList"])
                if not selected:
                    continue
                if dlid not in dlIdIndexMap:
                    # The next class index equals the number of label ids
                    # seen so far.
                    dlIdIndexMap[dlid] = len(dlOrderedList)
                    dlOrderedList.append(dlid)
                itemLabels.append([
                    dlIdIndexMap[dlid],
                    rec['rec_yolo_x'],
                    rec['rec_yolo_y'],
                    rec['rec_w'],
                    rec['rec_h'],
                ])

            imagePathList.append(fileUtils.getABSPath(imageItem['ditFilePath']))
            imageShapeList.append([imageItem['ditWidth'], imageItem['ditHeight']])
            labelsList.append(np.array(itemLabels))

    # The label lookup below needs a dataset item; fail clearly, and before
    # any database write, instead of hitting an unbound loop variable.
    if dsItem is None:
        raise ValueError("ds_dl_list is empty: no dataset selection to load")

    print("***************dlid_dlIndex_map**************")
    print(str(dlIdIndexMap))
    # Persist the label-id -> class-index mapping on the train version.
    detectModelTrainVersion.objects(
        dmtvid=dmtvid,
        state=ConstantUtils.DATA_STATUS_ACTIVE).update(
            dl_id_index_map=str(dlIdIndexMap))

    # NOTE(review): only the LAST dataset's dsId is used for the label
    # lookup -- confirm labelMap covers label ids from every dataset.
    labelMap, _ = labelService.getLabelsBylids(dsItem["dsId"])
    # Order the class names by their assigned class index.
    orderedNames = [labelMap[dlid] for dlid in dlOrderedList]
    loggerUtils.info("labelMap:" + str(labelMap))

    samples = zip(imagePathList, labelsList, imageShapeList)
    for index, (path, labels, shape) in enumerate(samples):
        print("-----------------****" + str(index) + "******-----------------")
        print(path)
        print(labels)
        print(shape)

    def buildDataDict():
        # Called once per returned dict so the two dicts do not share arrays.
        return {
            "imagePathList": imagePathList,
            "LabelsList": np.array(labelsList),
            "imageShapeList": np.array(imageShapeList),
            "nc": len(orderedNames),
            "names": orderedNames
        }

    trainDataDict = buildDataDict()
    valDataDict = buildDataDict()
    return trainDataDict, valDataDict
def updateDetectSessionTime(cls, serviceSessionId):
    """Record the current timestamp as the latest update time of a detect session.

    Args:
        serviceSessionId: Session key; concatenated into the log message and
            used as the key in ``cls.detectSessionMap``.
    """
    # Read the clock once so the logged value is exactly the stored value.
    timestamp = dateUtils.getTimeStamp()
    loggerUtils.info("update session:" + serviceSessionId + "/" + str(timestamp))
    cls.detectSessionMap[serviceSessionId] = timestamp
def closeVideoWrite(cls, vw):
    """Log that recording is being closed, then release the given writer.

    Args:
        vw: Object exposing a ``release()`` method.
    """
    loggerUtils.info("关闭录制。。。")
    vw.release()