def initializeOptimizer( cls ):
  """ Initialization of the Agent.
  """
  random.seed()
  cls.__SEStatus = DictCache.DictCache()
  cls.__sitesForSE = DictCache.DictCache()
  try:
    from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
  except ImportError, excp:
    return S_ERROR( "Could not import JobDB: %s" % str( excp ) )
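# Standalone sketch (not DIRAC code) of the DictCache semantics assumed by the
# snippets in this file: add( key, ttl, value ), get( key ) returning False on
# a miss, and purgeExpired() dropping stale entries.
import time

class TTLCacheSketch( object ):

  def __init__( self ):
    self.__store = {}  # key -> ( expiryEpoch, value )

  def add( self, key, ttl, value ):
    self.__store[ key ] = ( time.time() + ttl, value )

  def get( self, key ):
    entry = self.__store.get( key )
    if not entry or entry[ 0 ] < time.time():
      return False  # miss or expired; callers above test for falsiness
    return entry[ 1 ]

  def purgeExpired( self ):
    now = time.time()
    for key in list( self.__store ):
      if self.__store[ key ][ 0 ] < now:
        del self.__store[ key ]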
class PlotCache:

  def __init__( self, plotsLocation = False ):
    self.plotsLocation = plotsLocation
    self.alive = True
    self.__graphCache = DictCache( deleteFunction = _deleteGraph )
    self.__graphLifeTime = 600
    self.purgeThread = threading.Thread( target = self.purgeExpired )
    self.purgeThread.start()

  def setPlotsLocation( self, plotsDir ):
    self.plotsLocation = plotsDir
    for plot in os.listdir( self.plotsLocation ):
      if plot.find( ".png" ) > 0:
        plotLocation = "%s/%s" % ( self.plotsLocation, plot )
        gLogger.verbose( "Purging %s" % plotLocation )
        os.unlink( plotLocation )

  def purgeExpired( self ):
    while self.alive:
      time.sleep( self.__graphLifeTime )
      self.__graphCache.purgeExpired()

  def getPlot( self, plotHash, plotData, plotMetadata, subplotMetadata ):
    """ Get a plot from the cache if it exists, else generate it
    """
    plotDict = self.__graphCache.get( plotHash )
    if plotDict == False:
      basePlotFileName = "%s/%s.png" % ( self.plotsLocation, plotHash )
      if subplotMetadata:
        retVal = graph( plotData, basePlotFileName, plotMetadata, metadata = subplotMetadata )
      else:
        retVal = graph( plotData, basePlotFileName, plotMetadata )
      if not retVal[ 'OK' ]:
        return retVal
      plotDict = retVal[ 'Value' ]
      if plotDict[ 'plot' ]:
        plotDict[ 'plot' ] = os.path.basename( basePlotFileName )
      self.__graphCache.add( plotHash, self.__graphLifeTime, plotDict )
    return S_OK( plotDict )

  def getPlotData( self, plotFileName ):
    filename = "%s/%s" % ( self.plotsLocation, plotFileName )
    try:
      fd = open( filename, "rb" )
      data = fd.read()
      fd.close()
    except Exception, v:
      return S_ERROR( "Can't open file %s: %s" % ( plotFileName, str( v ) ) )
    return S_OK( data )
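# Usage sketch (illustrative only; plotData / plotMetadata are placeholders and
# the hash would normally come from hashing the plot request):
#
#   plotCache = PlotCache( plotsLocation = "/tmp/plots" )
#   res = plotCache.getPlot( "d41d8cd9", plotData, plotMetadata, None )
#   if res[ 'OK' ]:
#     pngName = res[ 'Value' ][ 'plot' ]  # PNG basename under plotsLocation
#
# Repeated calls with the same hash within __graphLifeTime seconds are served
# from the DictCache without calling graph() again.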
def initializeOptimizer(cls):
    """Initialize specific parameters
    """
    cls.ex_setProperty('shifterProxy', 'DataManager')
    cls.__SEStatus = DictCache.DictCache()
    try:
        cls.__replicaMan = ReplicaManager()
    except Exception, e:
        msg = 'Failed to create ReplicaManager'
        cls.log.exception(msg)
        return S_ERROR(msg + str(e))
def __init__( self, maxQueueSize = 10 ):
  random.seed()
  DB.__init__( self, 'TaskQueueDB', 'WorkloadManagement/TaskQueueDB', maxQueueSize )
  self.__multiValueDefFields = ( 'Sites', 'GridCEs', 'GridMiddlewares', 'BannedSites',
                                 'Platforms', 'PilotTypes', 'SubmitPools', 'JobTypes' )
  self.__multiValueMatchFields = ( 'GridCE', 'Site', 'GridMiddleware', 'Platform',
                                   'PilotType', 'SubmitPool', 'JobType' )
  self.__bannedJobMatchFields = ( 'Site', )
  self.__strictRequireMatchFields = ( 'SubmitPool', 'Platform', 'PilotType' )
  self.__singleValueDefFields = ( 'OwnerDN', 'OwnerGroup', 'Setup', 'CPUTime' )
  self.__mandatoryMatchFields = ( 'Setup', 'CPUTime' )
  self.__defaultCPUSegments = maxCPUSegments
  self.__maxMatchRetry = 3
  self.__jobPriorityBoundaries = ( 0.001, 10 )
  self.__groupShares = {}
  self.__deleteTQWithDelay = DictCache( self.__deleteTQIfEmpty )
  self.__opsHelper = Operations()
  self.__ensureInsertionIsSingle = False
  self.__sharesCorrector = SharesCorrector( self.__opsHelper )
  result = self.__initializeDB()
  if not result[ 'OK' ]:
    raise Exception( "Can't create tables: %s" % result[ 'Message' ] )
class AccountingPlotOldHandler(WebHandler):

    AUTH_PROPS = "authenticated"
    __keysCache = DictCache.DictCache()

    def __getUniqueKeyValues(self, typeName):
        sessionData = SessionData().getData()
        userGroup = sessionData["user"]["group"]
        if 'NormalUser' in CS.getPropertiesForGroup(userGroup):
            cacheKey = (sessionData["user"]["username"], userGroup,
                        sessionData["setup"], typeName)
        else:
            cacheKey = (userGroup, sessionData["setup"], typeName)
        data = AccountingPlotOldHandler.__keysCache.get(cacheKey)
        if not data:
            rpcClient = RPCClient("Accounting/ReportGenerator")
            retVal = rpcClient.listUniqueKeyValues(typeName)
            if 'rpcStub' in retVal:
                del retVal['rpcStub']
            if not retVal['OK']:
                return retVal
            # Site ordering based on TierLevel / alpha
            if 'Site' in retVal['Value']:
                siteLevel = {}
                for siteName in retVal['Value']['Site']:
                    sitePrefix = siteName.split(".")[0].strip()
                    level = gConfig.getValue("/Resources/Sites/%s/%s/MoUTierLevel" % (sitePrefix, siteName), 10)
                    if level not in siteLevel:
                        siteLevel[level] = []
                    siteLevel[level].append(siteName)
                orderedSites = []
                for level in sorted(siteLevel):
                    orderedSites.extend(sorted(siteLevel[level]))
                retVal['Value']['Site'] = orderedSites
            data = retVal
            AccountingPlotOldHandler.__keysCache.add(cacheKey, 300, data)
        return data

    def web_getSelectionData(self):
        callback = {}
        typeName = self.request.arguments["type"][0]
        # Get unique key values
        retVal = self.__getUniqueKeyValues(typeName)
        if not retVal['OK']:
            self.write(json.dumps({"success": "false", "result": "", "error": retVal['Message']}))
            return
        callback["selectionValues"] = json.dumps(retVal['Value'])
        # Cache for plotsList?
        data = AccountingPlotOldHandler.__keysCache.get("reportsList:%s" % typeName)
        if not data:
            repClient = ReportsClient(rpcClient=RPCClient("Accounting/ReportGenerator"))
            retVal = repClient.listReports(typeName)
            if not retVal['OK']:
                self.write(json.dumps({"success": "false", "result": "", "error": retVal['Message']}))
                return
            data = json.dumps(retVal['Value'])
            AccountingPlotOldHandler.__keysCache.add("reportsList:%s" % typeName, 300, data)
        callback["plotsList"] = data
        self.write(json.dumps({"success": "true", "result": callback}))

    def __parseFormParams(self):
        params = self.request.arguments
        pD = {}
        extraParams = {}
        pinDates = False
        for name in params:
            if name.find("_") != 0:
                continue
            value = params[name][0]
            name = name[1:]
            pD[name] = str(value)
        # Personalized title?
        if 'plotTitle' in pD:
            extraParams['plotTitle'] = pD['plotTitle']
            del pD['plotTitle']
        # Pin dates?
        if 'pinDates' in pD:
            pinDates = pD['pinDates']
            del pD['pinDates']
            pinDates = pinDates.lower() in ("yes", "y", "true", "1")
        # Get grouping
        if 'grouping' not in pD:
            return S_ERROR("Missing grouping!")
        grouping = pD['grouping']
        # Get type name
        if 'typeName' not in pD:
            return S_ERROR("Missing type name!")
        typeName = pD['typeName']
        del pD['typeName']
        # Get plot name
        if 'plotName' not in pD:
            return S_ERROR("Missing plot name!")
        reportName = pD['plotName']
        del pD['plotName']
        # Get times
        if 'timeSelector' not in pD:
            return S_ERROR("Missing time span!")
        # Find the proper time!
        pD['timeSelector'] = int(pD['timeSelector'])
        if pD['timeSelector'] > 0:
            end = Time.dateTime()
            start = end - datetime.timedelta(seconds=pD['timeSelector'])
            if not pinDates:
                extraParams['lastSeconds'] = pD['timeSelector']
        else:
            if 'endTime' not in pD:
                end = False
            else:
                end = Time.fromString(pD['endTime'])
                del pD['endTime']
            if 'startTime' not in pD:
                return S_ERROR("Missing startTime!")
            else:
                start = Time.fromString(pD['startTime'])
                del pD['startTime']
        del pD['timeSelector']
        for k in pD:
            if k.find("ex_") == 0:
                extraParams[k[3:]] = pD[k]
        # Listify the rest
        for selName in pD:
            pD[selName] = List.fromChar(pD[selName], ",")
        return S_OK((typeName, reportName, start, end, pD, grouping, extraParams))

    def web_generatePlot(self):
        callback = {}
        retVal = self.__queryForPlot()
        if retVal['OK']:
            callback = {'success': True, 'data': retVal['Value']['plot']}
        else:
            callback = {'success': False, 'errors': retVal['Message']}
        self.write(json.dumps(callback))

    def __queryForPlot(self):
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            return retVal
        params = retVal['Value']
        repClient = ReportsClient(rpcClient=RPCClient("Accounting/ReportGenerator"))
        retVal = repClient.generateDelayedPlot(*params)
        return retVal

    @asyncGen
    def web_getPlotImg(self):
        """ Get plot image
        """
        callback = {}
        if 'file' not in self.request.arguments:
            callback = {"success": "false", "error": "Maybe you forgot the file?"}
            self.finish(json.dumps(callback))
            return
        plotImageFile = str(self.request.arguments['file'][0])
        if plotImageFile.find(".png") == -1:
            callback = {"success": "false", "error": "Not a valid image!"}
            self.finish(json.dumps(callback))
            return
        transferClient = TransferClient("Accounting/ReportGenerator")
        tempFile = tempfile.TemporaryFile()
        retVal = yield self.threadTask(transferClient.receiveFile, tempFile, plotImageFile)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(json.dumps(callback))
            return
        tempFile.seek(0)
        data = tempFile.read()
        self.set_header('Content-type', 'image/png')
        self.set_header('Content-Disposition',
                        'attachment; filename="%s.png"' % md5(plotImageFile).hexdigest())
        self.set_header('Content-Length', len(data))
        self.set_header('Content-Transfer-Encoding', 'Binary')
        self.set_header('Cache-Control', "no-cache, no-store, must-revalidate, max-age=0")
        self.set_header('Pragma', "no-cache")
        self.set_header('Expires',
                        (datetime.datetime.utcnow() -
                         datetime.timedelta(minutes=-10)).strftime("%d %b %Y %H:%M:%S GMT"))
        self.finish(data)

    @asyncGen
    def web_getCsvPlotData(self):
        callback = {}
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        params = retVal['Value']
        repClient = ReportsClient(rpcClient=RPCClient("Accounting/ReportGenerator"))
        retVal = yield self.threadTask(repClient.getReport, *params)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        rawData = retVal['Value']
        groupKeys = rawData['data'].keys()
        groupKeys.sort()
        if 'granularity' in rawData:
            granularity = rawData['granularity']
            data = rawData['data']
            tS = int(Time.toEpoch(params[2]))
            timeStart = tS - tS % granularity
            strData = "epoch,%s\n" % ",".join(groupKeys)
            for timeSlot in range(timeStart, int(Time.toEpoch(params[3])), granularity):
                lineData = [str(timeSlot)]
                for key in groupKeys:
                    if timeSlot in data[key]:
                        lineData.append(str(data[key][timeSlot]))
                    else:
                        lineData.append("")
                strData += "%s\n" % ",".join(lineData)
        else:
            strData = "%s\n" % ",".join(groupKeys)
            strData += ",".join([str(rawData['data'][k]) for k in groupKeys])
        self.set_header('Content-type', 'text/csv')
        self.set_header('Content-Disposition',
                        'attachment; filename="%s.csv"' % md5(str(params)).hexdigest())
        self.set_header('Content-Length', len(strData))
        self.finish(strData)

    @asyncGen
    def web_getPlotData(self):
        callback = {}
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        params = retVal['Value']
        repClient = ReportsClient(rpcClient=RPCClient("Accounting/ReportGenerator"))
        retVal = yield self.threadTask(repClient.getReport, *params)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        rawData = retVal['Value']
        self.finish(json.dumps(rawData['data']))
class AccountingPlotHandler(WebHandler):

    AUTH_PROPS = "authenticated"
    __keysCache = DictCache()

    def __getUniqueKeyValues(self, typeName):
        sessionData = SessionData().getData()
        userGroup = sessionData["user"]["group"]
        if 'NormalUser' in CS.getPropertiesForGroup(userGroup):
            cacheKey = (sessionData["user"]["username"], userGroup,
                        sessionData["setup"], typeName)
        else:
            cacheKey = (userGroup, sessionData["setup"], typeName)
        data = AccountingPlotHandler.__keysCache.get(cacheKey)
        if not data:
            rpcClient = RPCClient("Accounting/ReportGenerator")
            retVal = rpcClient.listUniqueKeyValues(typeName)
            if 'rpcStub' in retVal:
                del retVal['rpcStub']
            if not retVal['OK']:
                return retVal
            # Site ordering based on TierLevel / alpha
            if 'Site' in retVal['Value']:
                siteLevel = {}
                for siteName in retVal['Value']['Site']:
                    sitePrefix = siteName.split(".")[0].strip()
                    level = gConfig.getValue("/Resources/Sites/%s/%s/MoUTierLevel" % (sitePrefix, siteName), 10)
                    if level not in siteLevel:
                        siteLevel[level] = []
                    siteLevel[level].append(siteName)
                orderedSites = []
                for level in sorted(siteLevel):
                    orderedSites.extend(sorted(siteLevel[level]))
                retVal['Value']['Site'] = orderedSites
            data = retVal
            AccountingPlotHandler.__keysCache.add(cacheKey, 300, data)
        return data

    def web_getSelectionData(self):
        callback = {}
        typeName = self.request.arguments["type"][0]
        # Get unique key values
        retVal = self.__getUniqueKeyValues(typeName)
        if not retVal['OK']:
            self.write(json.dumps({"success": "false", "result": "", "error": retVal['Message']}))
            return
        callback["selectionValues"] = simplejson.dumps(retVal['Value'])
        # Cache for plotsList?
        data = AccountingPlotHandler.__keysCache.get("reportsList:%s" % typeName)
        if not data:
            repClient = ReportsClient(rpcClient=RPCClient("Accounting/ReportGenerator"))
            retVal = repClient.listReports(typeName)
            if not retVal['OK']:
                self.write(json.dumps({"success": "false", "result": "", "error": retVal['Message']}))
                return
            data = simplejson.dumps(retVal['Value'])
            AccountingPlotHandler.__keysCache.add("reportsList:%s" % typeName, 300, data)
        callback["plotsList"] = data
        self.write(json.dumps({"success": "true", "result": callback}))
class DataCache:

  def __init__( self ):
    self.graphsLocation = os.path.join( gConfig.getValue( '/LocalSite/InstancePath', rootPath ),
                                        'data', 'accountingPlots' )
    self.cachedGraphs = {}
    self.alive = True
    self.purgeThread = threading.Thread( target = self.purgeExpired )
    self.purgeThread.start()
    self.__dataCache = DictCache()
    self.__graphCache = DictCache( deleteFunction = _deleteGraph )
    self.__dataLifeTime = 600
    self.__graphLifeTime = 3600

  def setGraphsLocation( self, graphsDir ):
    self.graphsLocation = graphsDir
    for graphName in os.listdir( self.graphsLocation ):
      if graphName.find( ".png" ) > 0:
        graphLocation = "%s/%s" % ( self.graphsLocation, graphName )
        gLogger.verbose( "Purging %s" % graphLocation )
        os.unlink( graphLocation )

  def purgeExpired( self ):
    while self.alive:
      time.sleep( 600 )
      self.__graphCache.purgeExpired()
      self.__dataCache.purgeExpired()

  def getReportData( self, reportRequest, reportHash, dataFunc ):
    """ Get report data from the cache if it exists, else generate it
    """
    reportData = self.__dataCache.get( reportHash )
    if reportData == False:
      retVal = dataFunc( reportRequest )
      if not retVal[ 'OK' ]:
        return retVal
      reportData = retVal[ 'Value' ]
      self.__dataCache.add( reportHash, self.__dataLifeTime, reportData )
    return S_OK( reportData )

  def getReportPlot( self, reportRequest, reportHash, reportData, plotFunc ):
    """ Get the report plot from the cache if it exists, else generate it
    """
    plotDict = self.__graphCache.get( reportHash )
    if plotDict == False:
      basePlotFileName = "%s/%s" % ( self.graphsLocation, reportHash )
      retVal = plotFunc( reportRequest, reportData, basePlotFileName )
      if not retVal[ 'OK' ]:
        return retVal
      plotDict = retVal[ 'Value' ]
      if plotDict[ 'plot' ]:
        plotDict[ 'plot' ] = "%s.png" % reportHash
      if plotDict[ 'thumbnail' ]:
        plotDict[ 'thumbnail' ] = "%s.thb.png" % reportHash
      self.__graphCache.add( reportHash, self.__graphLifeTime, plotDict )
    return S_OK( plotDict )

  def getPlotData( self, plotFileName ):
    filename = "%s/%s" % ( self.graphsLocation, plotFileName )
    try:
      fd = open( filename, "rb" )
      data = fd.read()
      fd.close()
    except Exception, e:
      return S_ERROR( "Can't open file %s: %s" % ( plotFileName, str( e ) ) )
    return S_OK( data )
class JobMonitorHandler(WebHandler):

    AUTH_PROPS = "authenticated"
    __dataCache = DictCache.DictCache()

    @asyncGen
    def web_getJobData(self):
        RPC = RPCClient("WorkloadManagement/JobMonitoring", timeout=600)
        req = self._request()
        result = yield self.threadTask(RPC.getJobPageSummaryWeb, req,
                                       self.globalSort, self.pageNumber,
                                       self.numberOfJobs)
        if not result["OK"]:
            self.finish({"success": "false", "result": [], "total": 0,
                         "error": result["Message"]})
            return
        result = result["Value"]
        if "TotalRecords" not in result:
            self.finish({"success": "false", "result": [], "total": -1,
                         "error": "Data structure is corrupted"})
            return
        if not (result["TotalRecords"] > 0):
            self.finish({"success": "false", "result": [], "total": 0,
                         "error": "There were no data matching your selection"})
            return
        if not ("ParameterNames" in result and "Records" in result):
            self.finish({"success": "false", "result": [], "total": -1,
                         "error": "Data structure is corrupted"})
            return
        if not (len(result["ParameterNames"]) > 0):
            self.finish({"success": "false", "result": [], "total": -1,
                         "error": "ParameterNames field is missing"})
            return
        if not (len(result["Records"]) > 0):
            self.finish({"success": "false", "result": [], "total": 0,
                         "Message": "There are no data to display"})
            return
        callback = []
        jobs = result["Records"]
        head = result["ParameterNames"]
        headLength = len(head)
        for i in jobs:
            tmp = {}
            for j in range(0, headLength):
                tmp[head[j]] = i[j]
            callback.append(tmp)
        total = result["TotalRecords"]
        if "Extras" in result:
            st = self.__dict2string({})
            extra = result["Extras"]
            timestamp = Time.dateTime().strftime("%Y-%m-%d %H:%M [UTC]")
            callback = {"success": "true", "result": callback, "total": total,
                        "extra": extra, "request": st, "date": timestamp}
        else:
            callback = {"success": "true", "result": callback, "total": total,
                        "date": None}
        self.finish(callback)

    def __dict2string(self, req):
        result = ""
        try:
            for key, value in req.iteritems():
                result = result + str(key) + ": " + ", ".join(value) + "; "
        except Exception, x:
            gLogger.info("\033[0;31m Exception: \033[0m %s" % x)
        result = result.strip()
        result = result[:-1]
        return result
class DataCache:

    def __init__(self):
        self.graphsLocation = os.path.join(
            gConfig.getValue('/LocalSite/InstancePath', rootPath), 'data',
            'accountingPlots')
        self.cachedGraphs = {}
        self.alive = True
        self.purgeThread = threading.Thread(target=self.purgeExpired)
        self.purgeThread.setDaemon(1)
        self.purgeThread.start()
        self.__dataCache = DictCache()
        self.__graphCache = DictCache(deleteFunction=self._deleteGraph)
        self.__dataLifeTime = 600
        self.__graphLifeTime = 3600

    def setGraphsLocation(self, graphsDir):
        self.graphsLocation = graphsDir
        for graphName in os.listdir(self.graphsLocation):
            if graphName.find(".png") > 0:
                graphLocation = "%s/%s" % (self.graphsLocation, graphName)
                gLogger.verbose("Purging %s" % graphLocation)
                os.unlink(graphLocation)

    def purgeExpired(self):
        while self.alive:
            time.sleep(600)
            self.__graphCache.purgeExpired()
            self.__dataCache.purgeExpired()

    def getReportData(self, reportRequest, reportHash, dataFunc):
        """ Get report data from the cache if it exists, else generate it
        """
        reportData = self.__dataCache.get(reportHash)
        if reportData == False:
            retVal = dataFunc(reportRequest)
            if not retVal['OK']:
                return retVal
            reportData = retVal['Value']
            self.__dataCache.add(reportHash, self.__dataLifeTime, reportData)
        return S_OK(reportData)

    def getReportPlot(self, reportRequest, reportHash, reportData, plotFunc):
        """ Get the report plot from the cache if it exists, else generate it
        """
        plotDict = self.__graphCache.get(reportHash)
        if plotDict == False:
            basePlotFileName = "%s/%s" % (self.graphsLocation, reportHash)
            retVal = plotFunc(reportRequest, reportData, basePlotFileName)
            if not retVal['OK']:
                return retVal
            plotDict = retVal['Value']
            if plotDict['plot']:
                plotDict['plot'] = "%s.png" % reportHash
            if plotDict['thumbnail']:
                plotDict['thumbnail'] = "%s.thb.png" % reportHash
            self.__graphCache.add(reportHash, self.__graphLifeTime, plotDict)
        return S_OK(plotDict)

    def getPlotData(self, plotFileName):
        filename = "%s/%s" % (self.graphsLocation, plotFileName)
        try:
            fd = open(filename, "rb")
            data = fd.read()
            fd.close()
        except Exception, e:
            return S_ERROR("Can't open file %s: %s" % (plotFileName, str(e)))
        return S_OK(data)
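# Standalone sketch (assumptions flagged inline) of the compute-on-miss pattern
# getReportData implements above: look up by report hash, call the supplied
# data function on a miss, and cache the result with a TTL.
import time

_reportCache = {}     # reportHash -> (expiryEpoch, reportData); stand-in for DictCache
_DATA_LIFETIME = 600  # seconds, matching __dataLifeTime above

def getReportDataSketch(reportHash, dataFunc):
    entry = _reportCache.get(reportHash)
    now = time.time()
    if entry and entry[0] > now:
        return entry[1]                  # cache hit, still fresh
    value = dataFunc()                   # stands in for dataFunc(reportRequest)
    _reportCache[reportHash] = (now + _DATA_LIFETIME, value)
    return value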
class MonitoringHandler(WebHandler):

    AUTH_PROPS = "authenticated"
    __keysCache = DictCache.DictCache()

    def __getUniqueKeyValues(self, typeName):
        sessionData = self.getSessionData()
        cacheKey = (sessionData["user"].get("username", ""),
                    sessionData["user"].get("group", ""),
                    sessionData["setup"], typeName)
        data = MonitoringHandler.__keysCache.get(cacheKey)
        if not data:
            client = MonitoringClient()
            retVal = client.listUniqueKeyValues(typeName)
            if 'rpcStub' in retVal:
                del retVal['rpcStub']
            if not retVal['OK']:
                return retVal
            # Site ordering based on TierLevel / alpha
            if 'Site' in retVal['Value']:
                siteLevel = {}
                for siteName in retVal['Value']['Site']:
                    sitePrefix = siteName.split(".")[0].strip()
                    level = gConfig.getValue("/Resources/Sites/%s/%s/MoUTierLevel" % (sitePrefix, siteName), 10)
                    if level not in siteLevel:
                        siteLevel[level] = []
                    siteLevel[level].append(siteName)
                orderedSites = []
                for level in sorted(siteLevel):
                    orderedSites.extend(sorted(siteLevel[level]))
                retVal['Value']['Site'] = orderedSites
            data = retVal
            MonitoringHandler.__keysCache.add(cacheKey, 300, data)
        return data

    @asyncGen
    def web_getSelectionData(self):
        callback = {}
        typeName = self.request.arguments["type"][0]
        # Get unique key values
        retVal = yield self.threadTask(self.__getUniqueKeyValues, typeName)
        if not retVal['OK']:
            self.finish({"success": "false", "result": "", "error": retVal['Message']})
            return
        records = {}
        for record in retVal['Value']:
            # A key may have a huge number of values; do not ship all of them
            # to the web portal.
            length = len(retVal['Value'][record])
            if length > 10000:
                records[record] = retVal['Value'][record][length - 5000:]
                message = "The %s accounting type contains too many rows: %s -> %d. Only the last 5000 rows are returned!" % (
                    typeName, record, length)
                gLogger.warn(message)
            else:
                records[record] = retVal['Value'][record]
        callback["selectionValues"] = records
        # Cache for plotsList?
        data = MonitoringHandler.__keysCache.get("reportsList:%s" % typeName)
        if not data:
            repClient = MonitoringClient()
            retVal = yield self.threadTask(repClient.listReports, typeName)
            if not retVal['OK']:
                self.finish({"success": "false", "result": "", "error": retVal['Message']})
                return
            data = retVal['Value']
            MonitoringHandler.__keysCache.add("reportsList:%s" % typeName, 300, data)
        callback["plotsList"] = data
        self.finish({"success": "true", "result": callback})

    def __parseFormParams(self):
        params = self.request.arguments
        pD = {}
        extraParams = {}
        pinDates = False
        for name in params:
            if name.find("_") != 0:
                continue
            value = params[name][0]
            name = name[1:]
            pD[name] = str(value)
        # Personalized title?
        if 'plotTitle' in pD:
            extraParams['plotTitle'] = pD['plotTitle']
            del pD['plotTitle']
        # Pin dates?
        if 'pinDates' in pD:
            pinDates = pD['pinDates']
            del pD['pinDates']
            pinDates = pinDates.lower() in ("yes", "y", "true", "1")
        # Get grouping
        if 'grouping' not in pD:
            return S_ERROR("Missing grouping!")
        grouping = pD['grouping']
        # Get type name
        if 'typeName' not in pD:
            return S_ERROR("Missing type name!")
        typeName = pD['typeName']
        del pD['typeName']
        # Get plot name
        if 'plotName' not in pD:
            return S_ERROR("Missing plot name!")
        reportName = pD['plotName']
        del pD['plotName']
        # Get times
        if 'timeSelector' not in pD:
            return S_ERROR("Missing time span!")
        # Find the proper time!
        pD['timeSelector'] = int(pD['timeSelector'])
        if pD['timeSelector'] > 0:
            end = Time.dateTime()
            start = end - datetime.timedelta(seconds=pD['timeSelector'])
            if not pinDates:
                extraParams['lastSeconds'] = pD['timeSelector']
        else:
            if 'endTime' not in pD:
                end = False
            else:
                end = Time.fromString(pD['endTime'])
                del pD['endTime']
            if 'startTime' not in pD:
                return S_ERROR("Missing startTime!")
            else:
                start = Time.fromString(pD['startTime'])
                del pD['startTime']
        del pD['timeSelector']
        for k in pD:
            if k.find("ex_") == 0:
                extraParams[k[3:]] = pD[k]
        # Listify the rest
        for selName in pD:
            if selName == 'grouping':
                pD[selName] = [pD[selName]]
            else:
                try:
                    pD[selName] = json.loads(pD[selName])
                except ValueError:
                    pD[selName] = List.fromChar(pD[selName], ",")
        return S_OK((typeName, reportName, start, end, pD, grouping, extraParams))

    @asyncGen
    def web_generatePlot(self):
        callback = {}
        retVal = yield self.threadTask(self.__queryForPlot)
        if retVal['OK']:
            callback = {'success': True, 'data': retVal['Value']['plot']}
        else:
            callback = {'success': False, 'errors': retVal['Message']}
        self.finish(callback)

    def __queryForPlot(self):
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            return retVal
        params = retVal['Value']
        repClient = MonitoringClient(rpcClient=RPCClient("Monitoring/Monitoring"))
        retVal = repClient.generateDelayedPlot(*params)
        return retVal

    @asyncGen
    def web_getPlotImg(self):
        """ Get plot image
        """
        callback = {}
        if 'file' not in self.request.arguments:
            callback = {"success": "false", "error": "Maybe you forgot the file?"}
            self.finish(callback)
            return
        plotImageFile = str(self.request.arguments['file'][0])
        if plotImageFile.find(".png") == -1:
            callback = {"success": "false", "error": "Not a valid image!"}
            self.finish(callback)
            return
        transferClient = TransferClient("Monitoring/Monitoring")
        tempFile = tempfile.TemporaryFile()
        retVal = yield self.threadTask(transferClient.receiveFile, tempFile, plotImageFile)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        tempFile.seek(0)
        data = tempFile.read()
        self.set_header('Content-type', 'image/png')
        self.set_header('Content-Disposition',
                        'attachment; filename="%s.png"' % md5(plotImageFile).hexdigest())
        self.set_header('Content-Length', len(data))
        self.set_header('Content-Transfer-Encoding', 'Binary')
        # self.set_header('Cache-Control', "no-cache, no-store, must-revalidate, max-age=0")
        # self.set_header('Pragma', "no-cache")
        # self.set_header('Expires', (datetime.datetime.utcnow() - datetime.timedelta(minutes=-10)).strftime("%d %b %Y %H:%M:%S GMT"))
        self.finish(data)

    @asyncGen
    def web_getPlotImgFromCache(self):
        """ Get plot image from cache.
        """
        callback = {}
        if 'file' not in self.request.arguments:
            callback = {"success": "false", "error": "Maybe you forgot the file?"}
            self.finish(callback)
            return
        plotImageFile = str(self.request.arguments['file'][0])
        retVal = extractRequestFromFileId(plotImageFile)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        fields = retVal['Value']
        if "extraArgs" in fields:
            # in order to get the plot from the cache we have to clean the extraArgs...
            plotTitle = ""
            if 'plotTitle' in fields["extraArgs"]:
                plotTitle = fields["extraArgs"]["plotTitle"]
            fields["extraArgs"] = {}
            fields["extraArgs"]["plotTitle"] = plotTitle
        else:
            fields["extraArgs"] = {}
        retVal = codeRequestInFileId(fields)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        plotImageFile = retVal['Value']['plot']
        transferClient = TransferClient("Monitoring/Monitoring")
        tempFile = tempfile.TemporaryFile()
        retVal = yield self.threadTask(transferClient.receiveFile, tempFile, plotImageFile)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        tempFile.seek(0)
        data = tempFile.read()
        self.set_header('Content-type', 'image/png')
        self.set_header('Content-Disposition',
                        'attachment; filename="%s.png"' % md5(plotImageFile).hexdigest())
        self.set_header('Content-Length', len(data))
        self.set_header('Content-Transfer-Encoding', 'Binary')
        self.set_header('Cache-Control', "no-cache, no-store, must-revalidate, max-age=0")
        self.set_header('Pragma', "no-cache")
        self.set_header('Expires',
                        (datetime.datetime.utcnow() -
                         datetime.timedelta(minutes=-10)).strftime("%d %b %Y %H:%M:%S GMT"))
        self.finish(data)

    @asyncGen
    def web_getCsvPlotData(self):
        callback = {}
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        params = retVal['Value']
        repClient = MonitoringClient(rpcClient=RPCClient("Monitoring/Monitoring"))
        retVal = yield self.threadTask(repClient.getReport, *params)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        rawData = retVal['Value']
        groupKeys = rawData['data'].keys()
        groupKeys.sort()
        if 'granularity' in rawData:
            granularity = rawData['granularity']
            data = rawData['data']
            tS = int(Time.toEpoch(params[2]))
            timeStart = tS - tS % granularity
            strData = "epoch,%s\n" % ",".join(groupKeys)
            for timeSlot in range(timeStart, int(Time.toEpoch(params[3])), granularity):
                lineData = [str(timeSlot)]
                for key in groupKeys:
                    if timeSlot in data[key]:
                        lineData.append(str(data[key][timeSlot]))
                    else:
                        lineData.append("")
                strData += "%s\n" % ",".join(lineData)
        else:
            strData = "%s\n" % ",".join(groupKeys)
            strData += ",".join([str(rawData['data'][k]) for k in groupKeys])
        self.set_header('Content-type', 'text/csv')
        self.set_header('Content-Disposition',
                        'attachment; filename="%s.csv"' % md5(str(params)).hexdigest())
        self.set_header('Content-Length', len(strData))
        self.finish(strData)

    @asyncGen
    def web_getPlotData(self):
        callback = {}
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        params = retVal['Value']
        repClient = MonitoringClient(rpcClient=RPCClient("Monitoring/Monitoring"))
        retVal = yield self.threadTask(repClient.getReport, *params)
        if not retVal['OK']:
            callback = {"success": "false", "error": retVal['Message']}
            self.finish(callback)
            return
        rawData = retVal['Value']
        self.finish(rawData['data'])
class AccountingplotsController(BaseController):

    __keysCache = DictCache()

    def __getUniqueKeyValues(self, typeName):
        userGroup = getSelectedGroup()
        if 'NormalUser' in CS.getPropertiesForGroup(userGroup):
            cacheKey = (getUserName(), userGroup, getSelectedSetup(), typeName)
        else:
            cacheKey = (userGroup, getSelectedSetup(), typeName)
        data = AccountingplotsController.__keysCache.get(cacheKey)
        if not data:
            rpcClient = getRPCClient("Accounting/ReportGenerator")
            retVal = rpcClient.listUniqueKeyValues(typeName)
            if 'rpcStub' in retVal:
                del retVal['rpcStub']
            if not retVal['OK']:
                return retVal
            # Site ordering based on TierLevel / alpha
            if 'Site' in retVal['Value']:
                siteLevel = {}
                for siteName in retVal['Value']['Site']:
                    sitePrefix = siteName.split(".")[0].strip()
                    level = gConfig.getValue("/Resources/Sites/%s/%s/MoUTierLevel" % (sitePrefix, siteName), 10)
                    if level not in siteLevel:
                        siteLevel[level] = []
                    siteLevel[level].append(siteName)
                orderedSites = []
                for level in sorted(siteLevel):
                    orderedSites.extend(sorted(siteLevel[level]))
                retVal['Value']['Site'] = orderedSites
            data = retVal
            AccountingplotsController.__keysCache.add(cacheKey, 300, data)
        return data

    def index(self):
        # Return a rendered template
        #   return render('/some/template.mako')
        # or, return a response
        return defaultRedirect()

    def dataOperation(self):
        return self.__showPlotPage("DataOperation", "/systems/accounting/dataOperation.mako")

    def job(self):
        return self.__showPlotPage("Job", "/systems/accounting/job.mako")

    def WMSHistory(self):
        return self.__showPlotPage("WMSHistory", "/systems/accounting/WMSHistory.mako")

    def pilot(self):
        return self.__showPlotPage("Pilot", "/systems/accounting/Pilot.mako")

    def SRMSpaceTokenDeployment(self):
        return self.__showPlotPage("SRMSpaceTokenDeployment",
                                   "/systems/accounting/SRMSpaceTokenDeployment.mako")

    def plotPage(self):
        try:
            typeName = str(request.params['typeName'])
        except:
            c.errorMessage = "Oops. Missing type"
            return render("/error.mako")
        return self.__showPlotPage(typeName, "/systems/accounting/%s.mako" % typeName)

    def __showPlotPage(self, typeName, templateFile):
        # Get unique key values
        retVal = self.__getUniqueKeyValues(typeName)
        if not retVal['OK']:
            c.error = retVal['Message']
            return render("/error.mako")
        c.selectionValues = simplejson.dumps(retVal['Value'])
        # Cache for plotsList?
        data = AccountingplotsController.__keysCache.get("reportsList:%s" % typeName)
        if not data:
            repClient = ReportsClient(rpcClient=getRPCClient("Accounting/ReportGenerator"))
            retVal = repClient.listReports(typeName)
            if not retVal['OK']:
                c.error = retVal['Message']
                return render("/error.mako")
            data = simplejson.dumps(retVal['Value'])
            AccountingplotsController.__keysCache.add("reportsList:%s" % typeName, 300, data)
        c.plotsList = data
        return render(templateFile)

    @jsonify
    def getKeyValuesForType(self):
        try:
            typeName = str(request.params['typeName'])
        except:
            return S_ERROR("Missing or invalid type name!")
        retVal = self.__getUniqueKeyValues(typeName)
        if not retVal['OK'] and 'rpcStub' in retVal:
            del retVal['rpcStub']
        return retVal

    def __parseFormParams(self):
        pD = {}
        extraParams = {}
        pinDates = False
        for name in request.params:
            if name.find("_") != 0:
                continue
            value = request.params[name]
            name = name[1:]
            pD[name] = str(value)
        # Personalized title?
        if 'plotTitle' in pD:
            extraParams['plotTitle'] = pD['plotTitle']
            del pD['plotTitle']
        # Pin dates?
        if 'pinDates' in pD:
            pinDates = pD['pinDates']
            del pD['pinDates']
            pinDates = pinDates.lower() in ("yes", "y", "true", "1")
        # Get grouping
        if 'grouping' not in pD:
            return S_ERROR("Missing grouping!")
        grouping = pD['grouping']
        # Get type name
        if 'typeName' not in pD:
            return S_ERROR("Missing type name!")
        typeName = pD['typeName']
        del pD['typeName']
        # Get plot name
        if 'plotName' not in pD:
            return S_ERROR("Missing plot name!")
        reportName = pD['plotName']
        del pD['plotName']
        # Get times
        if 'timeSelector' not in pD:
            return S_ERROR("Missing time span!")
        # Find the proper time!
        pD['timeSelector'] = int(pD['timeSelector'])
        if pD['timeSelector'] > 0:
            end = Time.dateTime()
            start = end - datetime.timedelta(seconds=pD['timeSelector'])
            if not pinDates:
                extraParams['lastSeconds'] = pD['timeSelector']
        else:
            if 'endTime' not in pD:
                end = False
            else:
                end = Time.fromString(pD['endTime'])
                del pD['endTime']
            if 'startTime' not in pD:
                return S_ERROR("Missing startTime!")
            else:
                start = Time.fromString(pD['startTime'])
                del pD['startTime']
        del pD['timeSelector']
        for k in pD:
            if k.find("ex_") == 0:
                extraParams[k[3:]] = pD[k]
        # Listify the rest
        for selName in pD:
            pD[selName] = List.fromChar(pD[selName], ",")
        return S_OK((typeName, reportName, start, end, pD, grouping, extraParams))

    def __translateToExpectedExtResult(self, retVal):
        if retVal['OK']:
            return {'success': True, 'data': retVal['Value']['plot']}
        else:
            return {'success': False, 'errors': retVal['Message']}

    def __queryForPlot(self):
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            return retVal
        params = retVal['Value']
        repClient = ReportsClient(rpcClient=getRPCClient("Accounting/ReportGenerator"))
        retVal = repClient.generateDelayedPlot(*params)
        return retVal

    def getPlotData(self):
        retVal = self.__parseFormParams()
        if not retVal['OK']:
            c.error = retVal['Message']
            return render("/error.mako")
        params = retVal['Value']
        repClient = ReportsClient(rpcClient=getRPCClient("Accounting/ReportGenerator"))
        retVal = repClient.getReport(*params)
        if not retVal['OK']:
            c.error = retVal['Message']
            return render("/error.mako")
        rawData = retVal['Value']
        groupKeys = rawData['data'].keys()
        groupKeys.sort()
        if 'granularity' in rawData:
            granularity = rawData['granularity']
            data = rawData['data']
            tS = int(Time.toEpoch(params[2]))
            timeStart = tS - tS % granularity
            strData = "epoch,%s\n" % ",".join(groupKeys)
            for timeSlot in range(timeStart, int(Time.toEpoch(params[3])), granularity):
                lineData = [str(timeSlot)]
                for key in groupKeys:
                    if timeSlot in data[key]:
                        lineData.append(str(data[key][timeSlot]))
                    else:
                        lineData.append("")
                strData += "%s\n" % ",".join(lineData)
        else:
            strData = "%s\n" % ",".join(groupKeys)
            strData += ",".join([str(rawData['data'][k]) for k in groupKeys])
        response.headers['Content-type'] = 'text/csv'
        response.headers['Content-Disposition'] = 'attachment; filename="%s.csv"' % md5(str(params)).hexdigest()
        response.headers['Content-Length'] = len(strData)
        return strData

    @jsonify
    def generatePlot(self):
        return self.__translateToExpectedExtResult(self.__queryForPlot())

    def generatePlotAndGetHTML(self):
        retVal = self.__queryForPlot()
        if not retVal['OK']:
            return "<h2>Can't regenerate plot: %s</h2>" % retVal['Message']
        return "<img src='getPlotImg?file=%s'/>" % retVal['Value']['plot']

    def getPlotImg(self):
        """ Get plot image
        """
        if 'file' not in request.params:
            c.error = "Maybe you forgot the file?"
            return render("/error.mako")
        plotImageFile = str(request.params['file'])
        if plotImageFile.find(".png") == -1:
            c.error = "Not a valid image!"
            return render("/error.mako")
        transferClient = getTransferClient("Accounting/ReportGenerator")
        tempFile = tempfile.TemporaryFile()
        retVal = transferClient.receiveFile(tempFile, plotImageFile)
        if not retVal['OK']:
            c.error = retVal['Message']
            return render("/error.mako")
        tempFile.seek(0)
        data = tempFile.read()
        response.headers['Content-type'] = 'image/png'
        response.headers['Content-Disposition'] = 'attachment; filename="%s.png"' % md5(plotImageFile).hexdigest()
        response.headers['Content-Length'] = len(data)
        response.headers['Content-Transfer-Encoding'] = 'Binary'
        response.headers['Cache-Control'] = "no-cache, no-store, must-revalidate, max-age=0"
        response.headers['Pragma'] = "no-cache"
        response.headers['Expires'] = (datetime.datetime.utcnow() -
                                       datetime.timedelta(minutes=-10)).strftime("%d %b %Y %H:%M:%S GMT")
        return data
class ProxyManagerClient:

  def __init__( self ):
    self.__usersCache = DictCache()
    self.__proxiesCache = DictCache()
    self.__vomsProxiesCache = DictCache()
    self.__pilotProxiesCache = DictCache()
    self.__filesCache = DictCache( self.__deleteTemporalFile )

  def __deleteTemporalFile( self, filename ):
    try:
      os.unlink( filename )
    except:
      pass

  def clearCaches( self ):
    self.__usersCache.purgeAll()
    self.__proxiesCache.purgeAll()
    self.__vomsProxiesCache.purgeAll()
    self.__pilotProxiesCache.purgeAll()

  def __getSecondsLeftToExpiration( self, expiration, utc = True ):
    if utc:
      td = expiration - datetime.datetime.utcnow()
    else:
      td = expiration - datetime.datetime.now()
    return td.days * 86400 + td.seconds

  def __refreshUserCache( self, validSeconds = 0 ):
    rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    retVal = rpcClient.getRegisteredUsers( validSeconds )
    if not retVal[ 'OK' ]:
      return retVal
    data = retVal[ 'Value' ]
    #Update the cache
    for record in data:
      cacheKey = ( record[ 'DN' ], record[ 'group' ] )
      self.__usersCache.add( cacheKey,
                             self.__getSecondsLeftToExpiration( record[ 'expirationtime' ] ),
                             record )
    return S_OK()

  @gUsersSync
  def userHasProxy( self, userDN, userGroup, validSeconds = 0 ):
    """ Check if a user (DN-group) has a proxy in the proxy management
        - Updates the internal cache if needed to minimize queries to the service
    """
    cacheKey = ( userDN, userGroup )
    if self.__usersCache.exists( cacheKey, validSeconds ):
      return S_OK( True )
    #Get list of users from the DB with proxys at least 300 seconds
    gLogger.verbose( "Updating list of users in proxy management" )
    retVal = self.__refreshUserCache( validSeconds )
    if not retVal[ 'OK' ]:
      return retVal
    return S_OK( self.__usersCache.exists( cacheKey, validSeconds ) )

  @gUsersSync
  def getUserPersistence( self, userDN, userGroup, validSeconds = 0 ):
    """ Check the persistency flag for a user (DN-group) in the proxy management
        - Updates the internal cache if needed to minimize queries to the service
    """
    cacheKey = ( userDN, userGroup )
    userData = self.__usersCache.get( cacheKey, validSeconds )
    if userData:
      if userData[ 'persistent' ]:
        return S_OK( True )
    #Get list of users from the DB with proxys at least 300 seconds
    gLogger.verbose( "Updating list of users in proxy management" )
    retVal = self.__refreshUserCache( validSeconds )
    if not retVal[ 'OK' ]:
      return retVal
    userData = self.__usersCache.get( cacheKey, validSeconds )
    if userData:
      return S_OK( userData[ 'persistent' ] )
    return S_OK( False )

  def setPersistency( self, userDN, userGroup, persistent ):
    """ Set the persistency for user/group
    """
    #Hack to ensure bool in the rpc call
    persistentFlag = True
    if not persistent:
      persistentFlag = False
    rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    retVal = rpcClient.setPersistency( userDN, userGroup, persistentFlag )
    if not retVal[ 'OK' ]:
      return retVal
    #Update internal persistency cache
    cacheKey = ( userDN, userGroup )
    record = self.__usersCache.get( cacheKey, 0 )
    if record:
      record[ 'persistent' ] = persistentFlag
      self.__usersCache.add( cacheKey,
                             self.__getSecondsLeftToExpiration( record[ 'expirationtime' ] ),
                             record )
    return retVal

  def uploadProxy( self, proxy = False, diracGroup = False, chainToConnect = False, restrictLifeTime = 0 ):
    """ Upload a proxy to the proxy management service using delegation
    """
    #Discover proxy location
    if type( proxy ) == g_X509ChainType:
      chain = proxy
      proxyLocation = ""
    else:
      if not proxy:
        proxyLocation = Locations.getProxyLocation()
        if not proxyLocation:
          return S_ERROR( "Can't find a valid proxy" )
      elif type( proxy ) in ( types.StringType, types.UnicodeType ):
        proxyLocation = proxy
      else:
        return S_ERROR( "Can't find a valid proxy" )
      chain = X509Chain()
      retVal = chain.loadProxyFromFile( proxyLocation )
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't load %s: %s" % ( proxyLocation, retVal[ 'Message' ] ) )
    if not chainToConnect:
      chainToConnect = chain
    #Make sure it's valid
    if chain.hasExpired()[ 'Value' ]:
      return S_ERROR( "Proxy %s has expired" % proxyLocation )
    #rpcClient = RPCClient( "Framework/ProxyManager", proxyChain = chainToConnect )
    rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    #Get a delegation request
    retVal = rpcClient.requestDelegationUpload( chain.getRemainingSecs()[ 'Value' ], diracGroup )
    if not retVal[ 'OK' ]:
      return retVal
    #Check if the delegation has been granted
    if 'Value' not in retVal or not retVal[ 'Value' ]:
      return S_OK()
    reqDict = retVal[ 'Value' ]
    #Generate delegated chain
    chainLifeTime = chain.getRemainingSecs()[ 'Value' ] - 60
    if restrictLifeTime and restrictLifeTime < chainLifeTime:
      chainLifeTime = restrictLifeTime
    retVal = chain.generateChainFromRequestString( reqDict[ 'request' ],
                                                   lifetime = chainLifeTime,
                                                   diracGroup = diracGroup )
    if not retVal[ 'OK' ]:
      return retVal
    #Upload!
    return rpcClient.completeDelegationUpload( reqDict[ 'id' ], retVal[ 'Value' ] )

  @gProxiesSync
  def downloadProxy( self, userDN, userGroup, limited = False, requiredTimeLeft = 43200,
                     proxyToConnect = False, token = False ):
    """ Get a proxy Chain from the proxy management
    """
    cacheKey = ( userDN, userGroup )
    if self.__proxiesCache.exists( cacheKey, requiredTimeLeft ):
      return S_OK( self.__proxiesCache.get( cacheKey ) )
    req = X509Request()
    req.generateProxyRequest( limited = limited )
    if proxyToConnect:
      rpcClient = RPCClient( "Framework/ProxyManager", proxyChain = proxyToConnect, timeout = 120 )
    else:
      rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    if token:
      retVal = rpcClient.getProxyWithToken( userDN, userGroup, req.dumpRequest()[ 'Value' ],
                                            long( requiredTimeLeft ), token )
    else:
      retVal = rpcClient.getProxy( userDN, userGroup, req.dumpRequest()[ 'Value' ],
                                   long( requiredTimeLeft ) )
    if not retVal[ 'OK' ]:
      return retVal
    chain = X509Chain( keyObj = req.getPKey() )
    retVal = chain.loadChainFromString( retVal[ 'Value' ] )
    if not retVal[ 'OK' ]:
      return retVal
    self.__proxiesCache.add( cacheKey, chain.getRemainingSecs()[ 'Value' ], chain )
    return S_OK( chain )

  def downloadProxyToFile( self, userDN, userGroup, limited = False, requiredTimeLeft = 43200,
                           filePath = False, proxyToConnect = False, token = False ):
    """ Get a proxy Chain from the proxy management and write it to file
    """
    retVal = self.downloadProxy( userDN, userGroup, limited, requiredTimeLeft, proxyToConnect, token )
    if not retVal[ 'OK' ]:
      return retVal
    chain = retVal[ 'Value' ]
    retVal = self.dumpProxyToFile( chain, filePath )
    if not retVal[ 'OK' ]:
      return retVal
    retVal[ 'chain' ] = chain
    return retVal

  @gVOMSProxiesSync
  def downloadVOMSProxy( self, userDN, userGroup, limited = False, requiredTimeLeft = 43200,
                         requiredVOMSAttribute = False, proxyToConnect = False, token = False ):
    """ Download a proxy if needed and transform it into a VOMS one
    """
    cacheKey = ( userDN, userGroup, requiredVOMSAttribute, limited )
    if self.__vomsProxiesCache.exists( cacheKey, requiredTimeLeft ):
      return S_OK( self.__vomsProxiesCache.get( cacheKey ) )
    req = X509Request()
    req.generateProxyRequest( limited = limited )
    if proxyToConnect:
      rpcClient = RPCClient( "Framework/ProxyManager", proxyChain = proxyToConnect, timeout = 120 )
    else:
      rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    if token:
      retVal = rpcClient.getVOMSProxyWithToken( userDN, userGroup, req.dumpRequest()[ 'Value' ],
                                                long( requiredTimeLeft ), token, requiredVOMSAttribute )
    else:
      retVal = rpcClient.getVOMSProxy( userDN, userGroup, req.dumpRequest()[ 'Value' ],
                                       long( requiredTimeLeft ), requiredVOMSAttribute )
    if not retVal[ 'OK' ]:
      return retVal
    chain = X509Chain( keyObj = req.getPKey() )
    retVal = chain.loadChainFromString( retVal[ 'Value' ] )
    if not retVal[ 'OK' ]:
      return retVal
    self.__vomsProxiesCache.add( cacheKey, chain.getRemainingSecs()[ 'Value' ], chain )
    return S_OK( chain )

  def downloadVOMSProxyToFile( self, userDN, userGroup, limited = False, requiredTimeLeft = 43200,
                               requiredVOMSAttribute = False, filePath = False,
                               proxyToConnect = False, token = False ):
    """ Download a proxy if needed, transform it into a VOMS one and write it to file
    """
    retVal = self.downloadVOMSProxy( userDN, userGroup, limited, requiredTimeLeft,
                                     requiredVOMSAttribute, proxyToConnect, token )
    if not retVal[ 'OK' ]:
      return retVal
    chain = retVal[ 'Value' ]
    retVal = self.dumpProxyToFile( chain, filePath )
    if not retVal[ 'OK' ]:
      return retVal
    retVal[ 'chain' ] = chain
    return retVal

  def getPilotProxyFromDIRACGroup( self, userDN, userGroup, requiredTimeLeft = 43200, proxyToConnect = False ):
    """ Download a pilot proxy with VOMS extensions depending on the group
    """
    #Assign VOMS attribute
    vomsAttr = CS.getVOMSAttributeForGroup( userGroup )
    if not vomsAttr:
      gLogger.verbose( "No voms attribute assigned to group %s when requested pilot proxy" % userGroup )
      return self.downloadProxy( userDN, userGroup, limited = False,
                                 requiredTimeLeft = requiredTimeLeft,
                                 proxyToConnect = proxyToConnect )
    else:
      return self.downloadVOMSProxy( userDN, userGroup, limited = False,
                                     requiredTimeLeft = requiredTimeLeft,
                                     requiredVOMSAttribute = vomsAttr,
                                     proxyToConnect = proxyToConnect )

  def getPilotProxyFromVOMSGroup( self, userDN, vomsAttr, requiredTimeLeft = 43200, proxyToConnect = False ):
    """ Download a pilot proxy with VOMS extensions depending on the VOMS attribute
    """
    groups = CS.getGroupsWithVOMSAttribute( vomsAttr )
    if not groups:
      return S_ERROR( "No group found that has %s as voms attrs" % vomsAttr )
    for userGroup in groups:
      result = self.downloadVOMSProxy( userDN, userGroup, limited = False,
                                       requiredTimeLeft = requiredTimeLeft,
                                       requiredVOMSAttribute = vomsAttr,
                                       proxyToConnect = proxyToConnect )
      if result[ 'OK' ]:
        return result
    return result

  def getPayloadProxyFromDIRACGroup( self, userDN, userGroup, requiredTimeLeft, token = False, proxyToConnect = False ):
    """ Download a payload proxy with VOMS extensions depending on the group
    """
    #Assign VOMS attribute
    vomsAttr = CS.getVOMSAttributeForGroup( userGroup )
    if not vomsAttr:
      gLogger.verbose( "No voms attribute assigned to group %s when requested payload proxy" % userGroup )
      return self.downloadProxy( userDN, userGroup, limited = True,
                                 requiredTimeLeft = requiredTimeLeft,
                                 proxyToConnect = proxyToConnect, token = token )
    else:
      return self.downloadVOMSProxy( userDN, userGroup, limited = True,
                                     requiredTimeLeft = requiredTimeLeft,
                                     requiredVOMSAttribute = vomsAttr,
                                     proxyToConnect = proxyToConnect, token = token )

  def getPayloadProxyFromVOMSGroup( self, userDN, vomsAttr, token, requiredTimeLeft, proxyToConnect = False ):
    """ Download a payload proxy with VOMS extensions depending on the VOMS attr
    """
    groups = CS.getGroupsWithVOMSAttribute( vomsAttr )
    if not groups:
      return S_ERROR( "No group found that has %s as voms attrs" % vomsAttr )
    userGroup = groups[ 0 ]
    return self.downloadVOMSProxy( userDN, userGroup, limited = True,
                                   requiredTimeLeft = requiredTimeLeft,
                                   requiredVOMSAttribute = vomsAttr,
                                   proxyToConnect = proxyToConnect, token = token )

  def dumpProxyToFile( self, chain, destinationFile = False, requiredTimeLeft = 600 ):
    """ Dump a proxy to a file. It's cached so multiple calls won't generate extra files
    """
    if self.__filesCache.exists( chain, requiredTimeLeft ):
      filepath = self.__filesCache.get( chain )
      if os.path.isfile( filepath ):
        return S_OK( filepath )
      #Cached file is gone: drop the cache entry (keyed by chain) and regenerate
      self.__filesCache.delete( chain )
    retVal = chain.dumpAllToFile( destinationFile )
    if not retVal[ 'OK' ]:
      return retVal
    filename = retVal[ 'Value' ]
    self.__filesCache.add( chain, chain.getRemainingSecs()[ 'Value' ], filename )
    return S_OK( filename )

  def deleteGeneratedProxyFile( self, chain ):
    """ Delete a file generated by a dump
    """
    self.__filesCache.delete( chain )
    return S_OK()

  def requestToken( self, requesterDN, requesterGroup, numUses = 1 ):
    """ Request a token valid for numUses uses
    """
    rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    return rpcClient.generateToken( requesterDN, requesterGroup, numUses )

  def renewProxy( self, proxyToBeRenewed = False, minLifeTime = 3600,
                  newProxyLifeTime = 43200, proxyToConnect = False ):
    """ Renew a proxy using the ProxyManager

        Arguments:
          proxyToBeRenewed : proxy to renew
          minLifeTime      : if the proxy life time is less than this, renew; skip otherwise
          newProxyLifeTime : life time of the new proxy
          proxyToConnect   : proxy to use for connecting to the service
    """
    retVal = File.multiProxyArgument( proxyToBeRenewed )
    if not retVal[ 'OK' ]:
      return retVal
    proxyToRenewDict = retVal[ 'Value' ]
    secs = proxyToRenewDict[ 'chain' ].getRemainingSecs()[ 'Value' ]
    if secs > minLifeTime:
      File.deleteMultiProxy( proxyToRenewDict )
      return S_OK()
    if not proxyToConnect:
      proxyToConnectDict = proxyToRenewDict
    else:
      retVal = File.multiProxyArgument( proxyToConnect )
      if not retVal[ 'OK' ]:
        File.deleteMultiProxy( proxyToRenewDict )
        return retVal
      proxyToConnectDict = retVal[ 'Value' ]
    userDN = proxyToRenewDict[ 'chain' ].getIssuerCert()[ 'Value' ].getSubjectDN()[ 'Value' ]
    retVal = proxyToRenewDict[ 'chain' ].getDIRACGroup()
    if not retVal[ 'OK' ]:
      File.deleteMultiProxy( proxyToRenewDict )
      File.deleteMultiProxy( proxyToConnectDict )
      return retVal
    userGroup = retVal[ 'Value' ]
    limited = proxyToRenewDict[ 'chain' ].isLimitedProxy()[ 'Value' ]
    voms = VOMS()
    retVal = voms.getVOMSAttributes( proxyToRenewDict[ 'chain' ] )
    if not retVal[ 'OK' ]:
      File.deleteMultiProxy( proxyToRenewDict )
      File.deleteMultiProxy( proxyToConnectDict )
      return retVal
    vomsAttrs = retVal[ 'Value' ]
    if vomsAttrs:
      retVal = self.downloadVOMSProxy( userDN, userGroup, limited = limited,
                                       requiredTimeLeft = newProxyLifeTime,
                                       requiredVOMSAttribute = vomsAttrs[ 0 ],
                                       proxyToConnect = proxyToConnectDict[ 'chain' ] )
    else:
      retVal = self.downloadProxy( userDN, userGroup, limited = limited,
                                   requiredTimeLeft = newProxyLifeTime,
                                   proxyToConnect = proxyToConnectDict[ 'chain' ] )
    File.deleteMultiProxy( proxyToRenewDict )
    File.deleteMultiProxy( proxyToConnectDict )
    if not retVal[ 'OK' ]:
      return retVal
    if not proxyToRenewDict[ 'tempFile' ]:
      return proxyToRenewDict[ 'chain' ].dumpAllToFile( proxyToRenewDict[ 'file' ] )
    return S_OK( proxyToRenewDict[ 'chain' ] )

  def getDBContents( self, condDict = {} ):
    """ Get the contents of the db
    """
    rpcClient = RPCClient( "Framework/ProxyManager", timeout = 120 )
    return rpcClient.getContents( condDict, [ [ 'UserDN', 'DESC' ] ], 0, 0 )

  def getVOMSAttributes( self, chain ):
    """ Get the voms attributes for a chain
    """
    return VOMS().getVOMSAttributes( chain )

  def getUploadedProxyLifeTime( self, DN, group ):
    """ Get the remaining seconds for an uploaded proxy
    """
    result = self.getDBContents( { 'UserDN' : [ DN ], 'UserGroup' : [ group ] } )
    if not result[ 'OK' ]:
      return result
    data = result[ 'Value' ]
    if len( data[ 'Records' ] ) == 0:
      return S_OK( 0 )
    pNames = list( data[ 'ParameterNames' ] )
    dnPos = pNames.index( 'UserDN' )
    groupPos = pNames.index( 'UserGroup' )
    expiryPos = pNames.index( 'ExpirationTime' )
    for row in data[ 'Records' ]:
      if DN == row[ dnPos ] and group == row[ groupPos ]:
        td = row[ expiryPos ] - datetime.datetime.utcnow()
        secondsLeft = td.days * 86400 + td.seconds
        return S_OK( max( 0, secondsLeft ) )
    return S_OK( 0 )
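# Usage sketch (illustrative; assumes a configured DIRAC client installation,
# a reachable Framework/ProxyManager service, and placeholder userDN/userGroup):
#
#   pmc = ProxyManagerClient()
#   res = pmc.downloadProxyToFile( userDN, userGroup, requiredTimeLeft = 3600 )
#   if res[ 'OK' ]:
#     proxyFile = res[ 'Value' ]   # path written by dumpProxyToFile
#     chain = res[ 'chain' ]       # the downloaded X509Chain
#
# Thanks to __proxiesCache, a second call for the same (DN, group) with enough
# remaining lifetime is served from memory without contacting the service.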
class TaskQueueDB( DB ):

  def __init__( self, maxQueueSize = 10 ):
    random.seed()
    DB.__init__( self, 'TaskQueueDB', 'WorkloadManagement/TaskQueueDB', maxQueueSize )
    self.__multiValueDefFields = ( 'Sites', 'GridCEs', 'GridMiddlewares', 'BannedSites',
                                   'LHCbPlatforms', 'PilotTypes', 'SubmitPools', 'JobTypes' )
    self.__multiValueMatchFields = ( 'GridCE', 'Site', 'GridMiddleware', 'LHCbPlatform',
                                     'PilotType', 'SubmitPool', 'JobType' )
    self.__bannedJobMatchFields = ( 'Site', )
    self.__strictRequireMatchFields = ( 'SubmitPool', 'LHCbPlatform', 'PilotType' )
    self.__singleValueDefFields = ( 'OwnerDN', 'OwnerGroup', 'Setup', 'CPUTime' )
    self.__mandatoryMatchFields = ( 'Setup', 'CPUTime' )
    self.__defaultCPUSegments = maxCPUSegments
    self.__maxMatchRetry = 3
    self.__jobPriorityBoundaries = ( 0.001, 10 )
    self.__groupShares = {}
    self.__deleteTQWithDelay = DictCache( self.__deleteTQIfEmpty )
    self.__opsHelper = Operations()
    self.__ensureInsertionIsSingle = False
    self.__sharesCorrector = SharesCorrector( self.__opsHelper )
    result = self.__initializeDB()
    if not result[ 'OK' ]:
      raise Exception( "Can't create tables: %s" % result[ 'Message' ] )

  def enableAllTaskQueues( self ):
    """ Enable all task queues
    """
    return self.updateFields( "tq_TaskQueues", updateDict = { "Enabled" : "1" } )

  def findOrphanJobs( self ):
    """ Find jobs that are not in any task queue
    """
    return self._query( "select JobID from tq_Jobs WHERE TQId not in (SELECT TQId from tq_TaskQueues)" )

  def isSharesCorrectionEnabled( self ):
    return self.__getCSOption( "EnableSharesCorrection", False )

  def getSingleValueTQDefFields( self ):
    return self.__singleValueDefFields

  def getMultiValueTQDefFields( self ):
    return self.__multiValueDefFields

  def getMultiValueMatchFields( self ):
    return self.__multiValueMatchFields

  def __getCSOption( self, optionName, defValue ):
    return self.__opsHelper.getValue( "Matching/%s" % optionName, defValue )

  def getPrivatePilots( self ):
    return self.__getCSOption( "PrivatePilotTypes", [ 'private' ] )

  def getValidPilotTypes( self ):
    return self.__getCSOption( "AllPilotTypes", [ 'private' ] )

  def __initializeDB( self ):
    """ Create the tables
    """
    result = self._query( "show tables" )
    if not result[ 'OK' ]:
      return result
    tablesInDB = [ t[0] for t in result[ 'Value' ] ]
    tablesToCreate = {}
    self.__tablesDesc = {}
    self.__tablesDesc[ 'tq_TaskQueues' ] = { 'Fields' : { 'TQId' : 'INTEGER UNSIGNED AUTO_INCREMENT NOT NULL',
                                                          'OwnerDN' : 'VARCHAR(255) NOT NULL',
                                                          'OwnerGroup' : 'VARCHAR(32) NOT NULL',
                                                          'Setup' : 'VARCHAR(32) NOT NULL',
                                                          'CPUTime' : 'BIGINT UNSIGNED NOT NULL',
                                                          'Priority' : 'FLOAT NOT NULL',
                                                          'Enabled' : 'TINYINT(1) NOT NULL DEFAULT 0' },
                                             'PrimaryKey' : 'TQId',
                                             'Indexes' : { 'TQOwner' : [ 'OwnerDN', 'OwnerGroup', 'Setup', 'CPUTime' ] }
                                           }
    self.__tablesDesc[ 'tq_Jobs' ] = { 'Fields' : { 'TQId' : 'INTEGER UNSIGNED NOT NULL',
                                                    'JobId' : 'INTEGER UNSIGNED NOT NULL',
                                                    'Priority' : 'INTEGER UNSIGNED NOT NULL',
                                                    'RealPriority' : 'FLOAT NOT NULL' },
                                       'PrimaryKey' : 'JobId',
                                       'Indexes' : { 'TaskIndex' : [ 'TQId' ] },
                                     }
    for multiField in self.__multiValueDefFields:
      tableName = 'tq_TQTo%s' % multiField
      self.__tablesDesc[ tableName ] = { 'Fields' : { 'TQId' : 'INTEGER UNSIGNED NOT NULL',
                                                      'Value' : 'VARCHAR(64) NOT NULL' },
                                         'Indexes' : { 'TaskIndex' : [ 'TQId' ],
                                                       '%sIndex' % multiField : [ 'Value' ] },
                                       }
    for tableName in self.__tablesDesc:
      if not tableName in tablesInDB:
        tablesToCreate[ tableName ] = self.__tablesDesc[ tableName ]
    return self._createTables( tablesToCreate )

  def getGroupsInTQs( self ):
    cmdSQL = "SELECT DISTINCT( OwnerGroup ) FROM `tq_TaskQueues`"
    result = self._query( cmdSQL )
    if not result[ 'OK' ]:
      return result
    return S_OK( [ row[0] for row in result[ 'Value' ] ] )

  def forceRecreationOfTables( self ):
    dropSQL = "DROP TABLE IF EXISTS %s" % ", ".join( self.__tablesDesc )
    result = self._update( dropSQL )
    if not result[ 'OK' ]:
      return result
    return self._createTables( self.__tablesDesc )

  def __strDict( self, dDict ):
    lines = []
    for key in sorted( dDict ):
      lines.append( " %s" % key )
      value = dDict[ key ]
      if type( value ) in ( types.ListType, types.TupleType ):
        lines.extend( [ " %s" % v for v in value ] )
      else:
        lines.append( " %s" % str( value ) )
    return "{\n%s\n}" % "\n".join( lines )

  def fitCPUTimeToSegments( self, cpuTime ):
    """ Fit the CPU time to the valid segments
    """
    maxCPUSegments = self.__getCSOption( "taskQueueCPUTimeIntervals", self.__defaultCPUSegments )
    try:
      maxCPUSegments = [ int( seg ) for seg in maxCPUSegments ]
      #Check segments in the CS
      last = 0
      for cpuS in maxCPUSegments:
        if cpuS <= last:
          maxCPUSegments = self.__defaultCPUSegments
          break
        last = cpuS
    except:
      maxCPUSegments = self.__defaultCPUSegments
    #Map to a segment
    for iP in range( len( maxCPUSegments ) ):
      cpuSegment = maxCPUSegments[ iP ]
      if cpuTime <= cpuSegment:
        return cpuSegment
    return maxCPUSegments[ -1 ]

  def _checkTaskQueueDefinition( self, tqDefDict ):
    """ Check that a task queue definition dict is valid
    """
    for field in self.__singleValueDefFields:
      if field not in tqDefDict:
        return S_ERROR( "Missing mandatory field '%s' in task queue definition" % field )
      fieldValueType = type( tqDefDict[ field ] )
      if field in [ "CPUTime" ]:
        if fieldValueType not in ( types.IntType, types.LongType ):
          return S_ERROR( "Mandatory field %s value type is not valid: %s" % ( field, fieldValueType ) )
      else:
        if fieldValueType not in ( types.StringType, types.UnicodeType ):
          return S_ERROR( "Mandatory field %s value type is not valid: %s" % ( field, fieldValueType ) )
        result = self._escapeString( tqDefDict[ field ] )
        if not result[ 'OK' ]:
          return result
        tqDefDict[ field ] = result[ 'Value' ]
    for field in self.__multiValueDefFields:
      if field not in tqDefDict:
        continue
      fieldValueType = type( tqDefDict[ field ] )
      if fieldValueType not in ( types.ListType, types.TupleType ):
        return S_ERROR( "Multi value field %s value type is not valid: %s" % ( field, fieldValueType ) )
      result = self._escapeValues( tqDefDict[ field ] )
      if not result[ 'OK' ]:
        return result
      tqDefDict[ field ] = result[ 'Value' ]
    #FIXME: This is not used
    if 'PrivatePilots' in tqDefDict:
      validPilotTypes = self.getValidPilotTypes()
      for pilotType in tqDefDict[ 'PrivatePilots' ]:
        if pilotType not in validPilotTypes:
          return S_ERROR( "PilotType %s is invalid" % pilotType )
    return S_OK( tqDefDict )

  def _checkMatchDefinition( self, tqMatchDict ):
    """ Check that a task queue match dict is valid
    """
    def travelAndCheckType( value, validTypes, escapeValues = True ):
      valueType = type( value )
      if valueType in ( types.ListType, types.TupleType ):
        for subValue in value:
          subValueType = type( subValue )
          if subValueType not in validTypes:
            return S_ERROR( "List contained type %s is not valid -> %s" % ( subValueType, validTypes ) )
        if escapeValues:
          return self._escapeValues( value )
        return S_OK( value )
      else:
        if valueType not in validTypes:
          return S_ERROR( "Type %s is not valid -> %s" % ( valueType, validTypes ) )
        if escapeValues:
          return self._escapeString( value )
        return S_OK( value )

    for field in self.__singleValueDefFields:
      if field not in tqMatchDict:
        if field in self.__mandatoryMatchFields:
          return S_ERROR( "Missing mandatory field '%s' in match request definition" % field )
        continue
      fieldValue = tqMatchDict[ field ]
      if field in [ "CPUTime" ]:
        result = travelAndCheckType( fieldValue, ( types.IntType, types.LongType ), escapeValues = False )
      else:
        result = travelAndCheckType( fieldValue, ( types.StringType, types.UnicodeType ) )
      if not result[ 'OK' ]:
        return S_ERROR( "Match definition field %s failed : %s" % ( field, result[ 'Message' ] ) )
      tqMatchDict[ field ] = result[ 'Value' ]
    #Check multivalue
    for multiField in self.__multiValueMatchFields:
      for field in ( multiField, "Banned%s" % multiField ):
        if field in tqMatchDict:
          fieldValue = tqMatchDict[ field ]
          result = travelAndCheckType( fieldValue, ( types.StringType, types.UnicodeType ) )
          if not result[ 'OK' ]:
            return S_ERROR( "Match definition field %s failed : %s" % ( field, result[ 'Message' ] ) )
          tqMatchDict[ field ] = result[ 'Value' ]
    return S_OK( tqMatchDict )

  def __createTaskQueue( self, tqDefDict, priority = 1, connObj = False ):
    """ Create a task queue

        Returns S_OK( tqId ) / S_ERROR
    """
    if not connObj:
      result = self._getConnection()
      if not result[ 'OK' ]:
        return S_ERROR( "Can't create task queue: %s" % result[ 'Message' ] )
      connObj = result[ 'Value' ]
    tqDefDict[ 'CPUTime' ] = self.fitCPUTimeToSegments( tqDefDict[ 'CPUTime' ] )
    sqlSingleFields = [ 'TQId', 'Priority' ]
    sqlValues = [ "0", str( priority ) ]
    for field in self.__singleValueDefFields:
      sqlSingleFields.append( field )
      sqlValues.append( tqDefDict[ field ] )
    #Insert the TQ Disabled
    sqlSingleFields.append( "Enabled" )
    sqlValues.append( "0" )
    cmd = "INSERT INTO tq_TaskQueues ( %s ) VALUES ( %s )" % ( ", ".join( sqlSingleFields ),
                                                               ", ".join( [ str( v ) for v in sqlValues ] ) )
    result = self._update( cmd, conn = connObj )
    if not result[ 'OK' ]:
      self.log.error( "Can't insert TQ in DB", result[ 'Value' ] )
      return result
    if 'lastRowId' in result:
      tqId = result[ 'lastRowId' ]
    else:
      result = self._query( "SELECT LAST_INSERT_ID()", conn = connObj )
      if not result[ 'OK' ]:
        self.cleanOrphanedTaskQueues( connObj = connObj )
        return S_ERROR( "Can't determine task queue id after insertion" )
      tqId = result[ 'Value' ][0][0]
    for field in self.__multiValueDefFields:
      if field not in tqDefDict:
        continue
      values = List.uniqueElements( [ value for value in tqDefDict[ field ] if value.strip() ] )
      if not values:
        continue
      cmd = "INSERT INTO `tq_TQTo%s` ( TQId, Value ) VALUES " % field
      cmd += ", ".join( [ "( %s, %s )" % ( tqId, str( value ) ) for value in values ] )
      result = self._update( cmd, conn = connObj )
      if not result[ 'OK' ]:
        self.log.error( "Failed to insert %s condition" % field, result[ 'Message' ] )
        self.cleanOrphanedTaskQueues( connObj = connObj )
        return S_ERROR( "Can't insert values %s for field %s: %s" % ( str( values ), field, result[ 'Message' ] ) )
    self.log.info( "Created TQ %s" % tqId )
    return S_OK( tqId )

  def cleanOrphanedTaskQueues( self, connObj = False ):
    """ Delete all empty task queues
    """
    self.log.info( "Cleaning orphaned TQs" )
    result = self._update( "DELETE FROM `tq_TaskQueues` WHERE Enabled >= 1 AND TQId not in ( SELECT DISTINCT TQId from `tq_Jobs` )",
                           conn = connObj )
    if not result[ 'OK' ]:
      return result
    for mvField in self.__multiValueDefFields:
      result = self._update( "DELETE FROM `tq_TQTo%s` WHERE TQId not in ( SELECT DISTINCT TQId from `tq_TaskQueues` )" % mvField,
                             conn = connObj )
      if not result[ 'OK' ]:
        return result
    return S_OK()

  def setTaskQueueState( self, tqId, enabled = True, connObj = False ):
    if enabled:
      enabled = "+ 1"
    else:
      enabled = "- 1"
    upSQL = "UPDATE `tq_TaskQueues` SET Enabled = Enabled %s WHERE TQId = %s" % ( enabled, tqId )
    return self._update( upSQL, conn = connObj )
WHERE TQId=%d" % ( enabled, tqId ) result = self._update( upSQL, conn = connObj ) if not result[ 'OK' ]: self.log.error( "Error setting TQ state", "TQ %s State %s: %s" % ( tqId, enabled, result[ 'Message' ] ) ) return result updated = result['Value'] > 0 if updated: self.log.info( "Set enabled = %s for TQ %s" % ( enabled, tqId ) ) return S_OK( updated ) def __hackJobPriority( self, jobPriority ): jobPriority = min( max( int( jobPriority ), self.__jobPriorityBoundaries[0] ), self.__jobPriorityBoundaries[1] ) if jobPriority == self.__jobPriorityBoundaries[0]: return 10 ** ( -5 ) if jobPriority == self.__jobPriorityBoundaries[1]: return 10 ** 6 return jobPriority def insertJob( self, jobId, tqDefDict, jobPriority, skipTQDefCheck = False, numRetries = 10 ): """ Insert a job in a task queue Returns S_OK( tqId ) / S_ERROR """ try: test = long( jobId ) except: return S_ERROR( "JobId is not a number!" ) retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't insert job: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] if not skipTQDefCheck: tqDefDict = dict( tqDefDict ) retVal = self._checkTaskQueueDefinition( tqDefDict ) if not retVal[ 'OK' ]: self.log.error( "TQ definition check failed", retVal[ 'Message' ] ) return retVal tqDefDict = retVal[ 'Value' ] tqDefDict[ 'CPUTime' ] = self.fitCPUTimeToSegments( tqDefDict[ 'CPUTime' ] ) self.log.info( "Inserting job %s with requirements: %s" % ( jobId, self.__strDict( tqDefDict ) ) ) retVal = self.__findAndDisableTaskQueue( tqDefDict, skipDefinitionCheck = True, connObj = connObj ) if not retVal[ 'OK' ]: return retVal tqInfo = retVal[ 'Value' ] newTQ = False if not tqInfo[ 'found' ]: self.log.info( "Creating a TQ for job %s" % jobId ) retVal = self.__createTaskQueue( tqDefDict, 1, connObj = connObj ) if not retVal[ 'OK' ]: return retVal tqId = retVal[ 'Value' ] newTQ = True else: tqId = tqInfo[ 'tqId' ] self.log.info( "Found TQ %s for job %s requirements" % ( tqId, jobId ) ) try: result = self.__insertJobInTaskQueue( jobId, tqId, int( jobPriority ), checkTQExists = False, connObj = connObj ) if not result[ 'OK' ]: self.log.error( "Error inserting job in TQ", "Job %s TQ %s: %s" % ( jobId, tqId, result[ 'Message' ] ) ) return result if newTQ: self.recalculateTQSharesForEntity( tqDefDict[ 'OwnerDN' ], tqDefDict[ 'OwnerGroup' ], connObj = connObj ) finally: self.setTaskQueueState( tqId, True ) return S_OK() def __insertJobInTaskQueue( self, jobId, tqId, jobPriority, checkTQExists = True, connObj = False ): """ Insert a job in a given task queue """ self.log.info( "Inserting job %s in TQ %s with priority %s" % ( jobId, tqId, jobPriority ) ) if not connObj: result = self._getConnection() if not result[ 'OK' ]: return S_ERROR( "Can't insert job: %s" % result[ 'Message' ] ) connObj = result[ 'Value' ] if checkTQExists: result = self._query( "SELECT tqId FROM `tq_TaskQueues` WHERE TQId = %s" % tqId, conn = connObj ) if not result[ 'OK' ] or len ( result[ 'Value' ] ) == 0: return S_OK( "Can't find task queue with id %s: %s" % ( tqId, result[ 'Message' ] ) ) hackedPriority = self.__hackJobPriority( jobPriority ) result = self._update( "INSERT INTO tq_Jobs ( TQId, JobId, Priority, RealPriority ) VALUES ( %s, %s, %s, %f )" % ( tqId, jobId, jobPriority, hackedPriority ), conn = connObj ) if not result[ 'OK' ] and result[ 'Message' ].find( "Duplicate entry" ) == -1: return result return S_OK() def __generateTQFindSQL( self, tqDefDict, skipDefinitionCheck = False, connObj = False ): """ Find a task queue that has exactly the same requirements 
""" if not skipDefinitionCheck: tqDefDict = dict( tqDefDict ) result = self._checkTaskQueueDefinition( tqDefDict ) if not result[ 'OK' ]: return result tqDefDict = result[ 'Value' ] sqlCondList = [] for field in self.__singleValueDefFields: sqlCondList.append( "`tq_TaskQueues`.%s = %s" % ( field, tqDefDict[ field ] ) ) #MAGIC SUBQUERIES TO ENSURE STRICT MATCH for field in self.__multiValueDefFields: tableName = '`tq_TQTo%s`' % field if field in tqDefDict and tqDefDict[ field ]: firstQuery = "SELECT COUNT(%s.Value) FROM %s WHERE %s.TQId = `tq_TaskQueues`.TQId" % ( tableName, tableName, tableName ) grouping = "GROUP BY %s.TQId" % tableName valuesList = List.uniqueElements( [ value.strip() for value in tqDefDict[ field ] if value.strip() ] ) numValues = len( valuesList ) secondQuery = "%s AND %s.Value in (%s)" % ( firstQuery, tableName, ",".join( [ "%s" % str( value ) for value in valuesList ] ) ) sqlCondList.append( "%s = (%s %s)" % ( numValues, firstQuery, grouping ) ) sqlCondList.append( "%s = (%s %s)" % ( numValues, secondQuery, grouping ) ) else: sqlCondList.append( "`tq_TaskQueues`.TQId not in ( SELECT DISTINCT %s.TQId from %s )" % ( tableName, tableName ) ) #END MAGIC: That was easy ;) return S_OK( " AND ".join( sqlCondList ) ) def __findAndDisableTaskQueue( self, tqDefDict, skipDefinitionCheck = False, retries = 10, connObj = False ): """ Disable and find TQ """ for i in range( retries ): result = self.findTaskQueue( tqDefDict, skipDefinitionCheck = skipDefinitionCheck, connObj = connObj ) if not result[ 'OK' ]: return result data = result[ 'Value' ] if not data[ 'found' ]: return result result = self._update( "UPDATE `tq_TaskQueues` SET Enabled = Enabled - 1 WHERE TQId = %d" % data[ 'tqId' ] ) if not result[ 'OK' ]: return result if result[ 'Value' ] > 0: return S_OK( data ) return S_ERROR( "Could not disable TQ" ) def findTaskQueue( self, tqDefDict, skipDefinitionCheck = False, connObj = False ): """ Find a task queue that has exactly the same requirements """ result = self.__generateTQFindSQL( tqDefDict, skipDefinitionCheck = skipDefinitionCheck, connObj = connObj ) if not result[ 'OK' ]: return result sqlCmd = "SELECT `tq_TaskQueues`.TQId FROM `tq_TaskQueues` WHERE" sqlCmd = "%s %s" % ( sqlCmd, result[ 'Value' ] ) result = self._query( sqlCmd, conn = connObj ) if not result[ 'OK' ]: return S_ERROR( "Can't find task queue: %s" % result[ 'Message' ] ) data = result[ 'Value' ] if len( data ) == 0: return S_OK( { 'found' : False } ) if len( data ) > 1: gLogger.warn( "Found two task queues for the same requirements", self.__strDict( tqDefDict ) ) return S_OK( { 'found' : True, 'tqId' : data[0][0] } ) def matchAndGetJob( self, tqMatchDict, numJobsPerTry = 50, numQueuesPerTry = 10, negativeCond = {} ): """ Match a job """ #Make a copy to avoid modification of original if escaping needs to be done tqMatchDict = dict( tqMatchDict ) self.log.info( "Starting match for requirements", self.__strDict( tqMatchDict ) ) retVal = self._checkMatchDefinition( tqMatchDict ) if not retVal[ 'OK' ]: self.log.error( "TQ match request check failed", retVal[ 'Message' ] ) return retVal retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't connect to DB: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] preJobSQL = "SELECT `tq_Jobs`.JobId, `tq_Jobs`.TQId FROM `tq_Jobs` WHERE `tq_Jobs`.TQId = %s AND `tq_Jobs`.Priority = %s" prioSQL = "SELECT `tq_Jobs`.Priority FROM `tq_Jobs` WHERE `tq_Jobs`.TQId = %s ORDER BY RAND() / `tq_Jobs`.RealPriority ASC LIMIT 1" postJobSQL = " ORDER BY 
`tq_Jobs`.JobId ASC LIMIT %s" % numJobsPerTry for matchTry in range( self.__maxMatchRetry ): if 'JobID' in tqMatchDict: # A certain JobID is required by the resource, so all TQ are to be considered retVal = self.matchAndGetTaskQueue( tqMatchDict, numQueuesToGet = 0, skipMatchDictDef = True, connObj = connObj ) preJobSQL = "%s AND `tq_Jobs`.JobId = %s " % ( preJobSQL, tqMatchDict['JobID'] ) else: retVal = self.matchAndGetTaskQueue( tqMatchDict, numQueuesToGet = numQueuesPerTry, skipMatchDictDef = True, negativeCond = negativeCond, connObj = connObj ) if not retVal[ 'OK' ]: return retVal tqList = retVal[ 'Value' ] if len( tqList ) == 0: self.log.info( "No TQ matches requirements" ) return S_OK( { 'matchFound' : False, 'tqMatch' : tqMatchDict } ) for tqId, tqOwnerDN, tqOwnerGroup in tqList: self.log.info( "Trying to extract jobs from TQ %s" % tqId ) retVal = self._query( prioSQL % tqId, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Can't retrieve winning priority for matching job: %s" % retVal[ 'Message' ] ) if len( retVal[ 'Value' ] ) == 0: continue prio = retVal[ 'Value' ][0][0] retVal = self._query( "%s %s" % ( preJobSQL % ( tqId, prio ), postJobSQL ), conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Can't begin transaction for matching job: %s" % retVal[ 'Message' ] ) jobTQList = [ ( row[0], row[1] ) for row in retVal[ 'Value' ] ] if len( jobTQList ) == 0: gLogger.info( "Task queue %s seems to be empty, triggering a cleaning" % tqId ) self.__deleteTQWithDelay.add( tqId, 300, ( tqId, tqOwnerDN, tqOwnerGroup ) ) while len( jobTQList ) > 0: jobId, tqId = jobTQList.pop( random.randint( 0, len( jobTQList ) - 1 ) ) self.log.info( "Trying to extract job %s from TQ %s" % ( jobId, tqId ) ) retVal = self.deleteJob( jobId, connObj = connObj ) if not retVal[ 'OK' ]: msgFix = "Could not take job" msgVar = " %s out from the TQ %s: %s" % ( jobId, tqId, retVal[ 'Message' ] ) self.log.error( msgFix, msgVar ) return S_ERROR( msgFix + msgVar ) if retVal[ 'Value' ] == True : self.log.info( "Extracted job %s with prio %s from TQ %s" % ( jobId, prio, tqId ) ) return S_OK( { 'matchFound' : True, 'jobId' : jobId, 'taskQueueId' : tqId, 'tqMatch' : tqMatchDict } ) self.log.info( "No jobs could be extracted from TQ %s" % tqId ) self.log.info( "Could not find a match after %s match retries" % self.__maxMatchRetry ) return S_ERROR( "Could not find a match after %s match retries" % self.__maxMatchRetry ) def matchAndGetTaskQueue( self, tqMatchDict, numQueuesToGet = 1, skipMatchDictDef = False, negativeCond = {}, connObj = False ): """ Get a queue that matches the requirements """ #Make a copy to avoid modification of original if escaping needs to be done tqMatchDict = dict( tqMatchDict ) if not skipMatchDictDef: retVal = self._checkMatchDefinition( tqMatchDict ) if not retVal[ 'OK' ]: return retVal retVal = self.__generateTQMatchSQL( tqMatchDict, numQueuesToGet = numQueuesToGet, negativeCond = negativeCond ) if not retVal[ 'OK' ]: return retVal matchSQL = retVal[ 'Value' ] retVal = self._query( matchSQL, conn = connObj ) if not retVal[ 'OK' ]: return retVal return S_OK( [ ( row[0], row[1], row[2] ) for row in retVal[ 'Value' ] ] ) def __generateSQLSubCond( self, sqlString, value, boolOp = 'OR' ): if type( value ) not in ( types.ListType, types.TupleType ): return sqlString % str( value ).strip() sqlORList = [] for v in value: sqlORList.append( sqlString % str( v ).strip() ) return "( %s )" % ( " %s " % boolOp ).join( sqlORList ) def __generateNotSQL( self, tableDict, negativeCond ): """ Generate 
negative conditions Can be a list of dicts or a dict: - list of dicts will be OR of conditional dicts - dicts will be normal conditional dict ( key1 in ( v1, v2, ... ) AND key2 in ( v3, v4, ... ) ) """ condType = type( negativeCond ) if condType in ( types.ListType, types.TupleType ): sqlCond = [] for cD in negativeCond: sqlCond.append( self.__generateNotDictSQL( tableDict, cD ) ) return " ( %s )" % " OR ".join( sqlCond ) elif condType == types.DictType: return self.__generateNotDictSQL( tableDict, negativeCond ) raise RuntimeError( "negativeCond has to be either a list or a dict and it's %s" % condType ) def __generateNotDictSQL( self, tableDict, negativeCond ): """ Generate the negative sql condition from a standard condition dict """ condList = [] for field in negativeCond: if field in self.__multiValueMatchFields: fullTableN = '`tq_TQTo%ss`' % field valList = negativeCond[ field ] if type( valList ) not in ( types.TupleType, types.ListType ): valList = ( valList, ) for value in valList: value = self._escapeString( value )[ 'Value' ] sql = "%s NOT IN ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( value, fullTableN, fullTableN, fullTableN ) condList.append( sql ) elif field in self.__singleValueDefFields: for value in negativeCond[field]: value = self._escapeString( value )[ 'Value' ] sql = "%s != tq.%s " % ( value, field ) condList.append( sql ) return "( %s )" % " AND ".join( condList ) def __generateTablesName( self, sqlTables, field ): fullTableName = 'tq_TQTo%ss' % field if fullTableName not in sqlTables: tableN = field.lower() sqlTables[ fullTableName ] = tableN return tableN, "`%s`" % fullTableName, return sqlTables[ fullTableName ], "`%s`" % fullTableName def __generateTQMatchSQL( self, tqMatchDict, numQueuesToGet = 1, negativeCond = {} ): """ Generate the SQL needed to match a task queue """ #Only enabled TQs #sqlCondList = [ "Enabled >= 1" ] sqlCondList = [] sqlTables = { "tq_TaskQueues" : "tq" } #If OwnerDN and OwnerGroup are defined only use those combinations that make sense if 'OwnerDN' in tqMatchDict and 'OwnerGroup' in tqMatchDict: groups = tqMatchDict[ 'OwnerGroup' ] if type( groups ) not in ( types.ListType, types.TupleType ): groups = [ groups ] dns = tqMatchDict[ 'OwnerDN' ] if type( dns ) not in ( types.ListType, types.TupleType ): dns = [ dns ] ownerConds = [] for group in groups: if Properties.JOB_SHARING in CS.getPropertiesForGroup( group.replace( '"', "" ) ): ownerConds.append( "tq.OwnerGroup = %s" % group ) else: for dn in dns: ownerConds.append( "( tq.OwnerDN = %s AND tq.OwnerGroup = %s )" % ( dn, group ) ) sqlCondList.append( " OR ".join( ownerConds ) ) else: #If not both are defined, just add the ones that are defined for field in ( 'OwnerGroup', 'OwnerDN' ): if field in tqMatchDict: sqlCondList.append( self.__generateSQLSubCond( "tq.%s = %%s" % field, tqMatchDict[ field ] ) ) #Type of single value conditions for field in ( 'CPUTime', 'Setup' ): if field in tqMatchDict: if field in ( 'CPUTime', ): sqlCondList.append( self.__generateSQLSubCond( "tq.%s <= %%s" % field, tqMatchDict[ field ] ) ) else: sqlCondList.append( self.__generateSQLSubCond( "tq.%s = %%s" % field, tqMatchDict[ field ] ) ) #Match multi value fields for field in self.__multiValueMatchFields: #It has to be %ss , with an 's' at the end because the column names # are plural and match options are singular if field in tqMatchDict and tqMatchDict[ field ]: tableN, fullTableN = self.__generateTablesName( sqlTables, field ) sqlMultiCondList = [] if field != 'GridCE' or 'Site' in 
tqMatchDict: # Jobs for masked sites can be matched if they specified a GridCE # Site is removed from tqMatchDict if the Site is masked. In this case we want # the GridCE to match explicitly so the COUNT cannot be 0; in that case we skip this # condition sqlMultiCondList.append( "( SELECT COUNT(%s.Value) FROM %s WHERE %s.TQId = tq.TQId ) = 0" % ( fullTableN, fullTableN, fullTableN ) ) csql = self.__generateSQLSubCond( "%%s IN ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( fullTableN, fullTableN, fullTableN ), tqMatchDict[ field ] ) sqlMultiCondList.append( csql ) sqlCondList.append( "( %s )" % " OR ".join( sqlMultiCondList ) ) #In case of Site, check it's not in job banned sites if field in self.__bannedJobMatchFields: fullTableN = '`tq_TQToBanned%ss`' % field csql = self.__generateSQLSubCond( "%%s not in ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( fullTableN, fullTableN, fullTableN ), tqMatchDict[ field ], boolOp = 'AND' ) sqlCondList.append( csql ) #Resource banning bannedField = "Banned%s" % field if bannedField in tqMatchDict and tqMatchDict[ bannedField ]: fullTableN = '`tq_TQTo%ss`' % field csql = self.__generateSQLSubCond( "%%s not in ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( fullTableN, fullTableN, fullTableN ), tqMatchDict[ bannedField ], boolOp = 'AND' ) sqlCondList.append( csql ) #For certain fields, the requirement is strict. If it is not in the tqMatchDict, the job cannot require it for field in self.__strictRequireMatchFields: if field in tqMatchDict: continue fullTableN = '`tq_TQTo%ss`' % field sqlCondList.append( "( SELECT COUNT(%s.Value) FROM %s WHERE %s.TQId = tq.TQId ) = 0" % ( fullTableN, fullTableN, fullTableN ) ) # Add extra conditions if negativeCond: sqlCondList.append( self.__generateNotSQL( sqlTables, negativeCond ) ) #Generate the final query string tqSqlCmd = "SELECT tq.TQId, tq.OwnerDN, tq.OwnerGroup FROM `tq_TaskQueues` tq WHERE %s" % ( " AND ".join( sqlCondList ) ) #Apply priorities tqSqlCmd = "%s ORDER BY RAND() / tq.Priority ASC" % tqSqlCmd #Do we want a limit? 
if numQueuesToGet: tqSqlCmd = "%s LIMIT %s" % ( tqSqlCmd, numQueuesToGet ) return S_OK( tqSqlCmd ) def deleteJob( self, jobId, connObj = False ): """ Delete a job from the task queues Return S_OK( True/False ) / S_ERROR """ self.log.info( "Deleting job %s" % jobId ) if not connObj: retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't delete job: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] retVal = self._query( "SELECT t.TQId, t.OwnerDN, t.OwnerGroup FROM `tq_TaskQueues` t, `tq_Jobs` j WHERE j.JobId = %s AND t.TQId = j.TQId" % jobId, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Could not get job from task queue %s: %s" % ( jobId, retVal[ 'Message' ] ) ) data = retVal[ 'Value' ] if not data: return S_OK( False ) tqId, tqOwnerDN, tqOwnerGroup = data[0] retVal = self._update( "DELETE FROM `tq_Jobs` WHERE JobId = %s" % jobId, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Could not delete job from task queue %s: %s" % ( jobId, retVal[ 'Message' ] ) ) if retVal[ 'Value' ] == 0: #No job deleted return S_OK( False ) #Always return S_OK() because job has already been taken out from the TQ self.__deleteTQWithDelay.add( tqId, 300, ( tqId, tqOwnerDN, tqOwnerGroup ) ) return S_OK( True ) def getTaskQueueForJob( self, jobId, connObj = False ): """ Return TaskQueue for a given Job Return S_OK( [TaskQueueID] ) / S_ERROR """ if not connObj: retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't get TQ for job: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] retVal = self._query( 'SELECT TQId FROM `tq_Jobs` WHERE JobId = %s ' % jobId, conn = connObj ) if not retVal[ 'OK' ]: return retVal if not retVal['Value']: return S_ERROR( 'Not in TaskQueues' ) return S_OK( retVal['Value'][0][0] ) def getTaskQueueForJobs( self, jobIDs, connObj = False ): """ Return TaskQueues for a given list of Jobs """ if not connObj: retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't get TQs for a job list: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] jobString = ','.join( [ str( x ) for x in jobIDs ] ) retVal = self._query( 'SELECT JobId,TQId FROM `tq_Jobs` WHERE JobId in (%s) ' % jobString, conn = connObj ) if not retVal[ 'OK' ]: return retVal if not retVal['Value']: return S_ERROR( 'Not in TaskQueues' ) resultDict = {} for jobID, TQID in retVal['Value']: resultDict[int( jobID )] = int( TQID ) return S_OK( resultDict ) def __getOwnerForTaskQueue( self, tqId, connObj = False ): retVal = self._query( "SELECT OwnerDN, OwnerGroup from `tq_TaskQueues` WHERE TQId=%s" % tqId, conn = connObj ) if not retVal[ 'OK' ]: return retVal data = retVal[ 'Value' ] if len( data ) == 0: return S_OK( False ) return S_OK( retVal[ 'Value' ][0] ) def __deleteTQIfEmpty( self, args ): ( tqId, tqOwnerDN, tqOwnerGroup ) = args retries = 3 while retries: retries -= 1 result = self.deleteTaskQueueIfEmpty( tqId, tqOwnerDN, tqOwnerGroup ) if result[ 'OK' ]: return gLogger.error( "Could not delete TQ %s: %s" % ( tqId, result[ 'Message' ] ) ) def deleteTaskQueueIfEmpty( self, tqId, tqOwnerDN = False, tqOwnerGroup = False, connObj = False ): """ Try to delete a task queue if it's empty """ if not connObj: retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't delete task queue: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] if not tqOwnerDN or not tqOwnerGroup: retVal = self.__getOwnerForTaskQueue( tqId, connObj = 
connObj ) if not retVal[ 'OK' ]: return retVal data = retVal[ 'Value' ] if not data: return S_OK( False ) tqOwnerDN, tqOwnerGroup = data sqlCmd = "DELETE FROM `tq_TaskQueues` WHERE Enabled >= 1 AND `tq_TaskQueues`.TQId = %s" % tqId sqlCmd = "%s AND `tq_TaskQueues`.TQId not in ( SELECT DISTINCT TQId from `tq_Jobs` )" % sqlCmd retVal = self._update( sqlCmd, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Could not delete task queue %s: %s" % ( tqId, retVal[ 'Message' ] ) ) delTQ = retVal[ 'Value' ] if delTQ > 0: for mvField in self.__multiValueDefFields: retVal = self._update( "DELETE FROM `tq_TQTo%s` WHERE TQId = %s" % ( mvField, tqId ), conn = connObj ) if not retVal[ 'OK' ]: return retVal self.recalculateTQSharesForEntity( tqOwnerDN, tqOwnerGroup, connObj = connObj ) self.log.info( "Deleted empty and enabled TQ %s" % tqId ) return S_OK( True ) return S_OK( False ) def deleteTaskQueue( self, tqId, tqOwnerDN = False, tqOwnerGroup = False, connObj = False ): """ Try to delete a task queue even if it has jobs """ self.log.info( "Deleting TQ %s" % tqId ) if not connObj: retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't delete task queue: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] if not tqOwnerDN or not tqOwnerGroup: retVal = self.__getOwnerForTaskQueue( tqId, connObj = connObj ) if not retVal[ 'OK' ]: return retVal data = retVal[ 'Value' ] if not data: return S_OK( False ) tqOwnerDN, tqOwnerGroup = data sqlCmd = "DELETE FROM `tq_TaskQueues` WHERE `tq_TaskQueues`.TQId = %s" % tqId retVal = self._update( sqlCmd, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Could not delete task queue %s: %s" % ( tqId, retVal[ 'Message' ] ) ) delTQ = retVal[ 'Value' ] sqlCmd = "DELETE FROM `tq_Jobs` WHERE `tq_Jobs`.TQId = %s" % tqId retVal = self._update( sqlCmd, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Could not delete task queue %s: %s" % ( tqId, retVal[ 'Message' ] ) ) for mvField in self.__multiValueDefFields: retVal = self._update( "DELETE FROM `tq_TQTo%s` WHERE TQId = %s" % ( mvField, tqId ), conn = connObj ) if not retVal[ 'OK' ]: return retVal if delTQ > 0: self.recalculateTQSharesForEntity( tqOwnerDN, tqOwnerGroup, connObj = connObj ) return S_OK( True ) return S_OK( False ) def getMatchingTaskQueues( self, tqMatchDict, negativeCond = False ): """ Alias kept so the method name matches the one exposed in the Matcher """ return self.retrieveTaskQueuesThatMatch( tqMatchDict, negativeCond = negativeCond ) def getNumTaskQueues( self ): """ Get the number of task queues in the system """ sqlCmd = "SELECT COUNT( TQId ) FROM `tq_TaskQueues`" retVal = self._query( sqlCmd ) if not retVal[ 'OK' ]: return retVal return S_OK( retVal[ 'Value' ][0][0] ) def retrieveTaskQueuesThatMatch( self, tqMatchDict, negativeCond = False ): """ Get the info of the task queues that match a resource """ result = self.matchAndGetTaskQueue( tqMatchDict, numQueuesToGet = 0, negativeCond = negativeCond ) if not result[ 'OK' ]: return result return self.retrieveTaskQueues( [ tqTuple[0] for tqTuple in result[ 'Value' ] ] ) def retrieveTaskQueues( self, tqIdList = False ): """ Get all the task queues """ sqlSelectEntries = [ "`tq_TaskQueues`.TQId", "`tq_TaskQueues`.Priority", "COUNT( `tq_Jobs`.TQId )" ] sqlGroupEntries = [ "`tq_TaskQueues`.TQId", "`tq_TaskQueues`.Priority" ] for field in self.__singleValueDefFields: sqlSelectEntries.append( "`tq_TaskQueues`.%s" % field ) sqlGroupEntries.append( "`tq_TaskQueues`.%s" % field ) sqlCmd = "SELECT %s FROM `tq_TaskQueues`, `tq_Jobs`" % ", ".join( 
sqlSelectEntries ) sqlTQCond = "AND Enabled >= 1" if tqIdList != False: if len( tqIdList ) == 0: return S_OK( {} ) else: sqlTQCond += " AND `tq_TaskQueues`.TQId in ( %s )" % ", ".join( [ str( id ) for id in tqIdList ] ) sqlCmd = "%s WHERE `tq_TaskQueues`.TQId = `tq_Jobs`.TQId %s GROUP BY %s" % ( sqlCmd, sqlTQCond, ", ".join( sqlGroupEntries ) ) retVal = self._query( sqlCmd ) if not retVal[ 'OK' ]: return S_ERROR( "Can't retrieve task queues info: %s" % retVal[ 'Message' ] ) tqData = {} for record in retVal[ 'Value' ]: tqId = record[0] tqData[ tqId ] = { 'Priority' : record[1], 'Jobs' : record[2] } record = record[3:] for iP in range( len( self.__singleValueDefFields ) ): tqData[ tqId ][ self.__singleValueDefFields[ iP ] ] = record[ iP ] tqNeedCleaning = False for field in self.__multiValueDefFields: table = "`tq_TQTo%s`" % field sqlCmd = "SELECT %s.TQId, %s.Value FROM %s" % ( table, table, table ) retVal = self._query( sqlCmd ) if not retVal[ 'OK' ]: return S_ERROR( "Can't retrieve task queues field %s info: %s" % ( field, retVal[ 'Message' ] ) ) for record in retVal[ 'Value' ]: tqId = record[0] value = record[1] if not tqId in tqData: if tqIdList == False or tqId in tqIdList: self.log.warn( "Task Queue %s is defined in field %s but does not exist, triggering a cleaning" % ( tqId, field ) ) tqNeedCleaning = True else: if field not in tqData[ tqId ]: tqData[ tqId ][ field ] = [] tqData[ tqId ][ field ].append( value ) if tqNeedCleaning: self.cleanOrphanedTaskQueues() return S_OK( tqData ) def __updateGlobalShares( self ): """ Update internal structure for shares """ #Update group shares self.__groupShares = self.getGroupShares() #Apply corrections if enabled if self.isSharesCorrectionEnabled(): result = self.getGroupsInTQs() if not result[ 'OK' ]: self.log.error( "Could not get groups in the TQs", result[ 'Message' ] ) return activeGroups = result[ 'Value' ] newShares = {} for group in activeGroups: if group in self.__groupShares: newShares[ group ] = self.__groupShares[ group ] newShares = self.__sharesCorrector.correctShares( newShares ) for group in self.__groupShares: if group in newShares: self.__groupShares[ group ] = newShares[ group ] def recalculateTQSharesForAll( self ): """ Recalculate all priorities for TQs """ if self.isSharesCorrectionEnabled(): self.log.info( "Updating correctors state" ) self.__sharesCorrector.update() self.__updateGlobalShares() self.log.info( "Recalculating shares for all TQs" ) retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't get a connection: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] result = self._query( "SELECT DISTINCT( OwnerGroup ) FROM `tq_TaskQueues`" ) if not result[ 'OK' ]: return result for group in [ r[0] for r in result[ 'Value' ] ]: self.recalculateTQSharesForEntity( "all", group ) return S_OK() def recalculateTQSharesForEntity( self, userDN, userGroup, connObj = False ): """ Recalculate the shares for a userDN/userGroup combo """ self.log.info( "Recalculating shares for %s@%s TQs" % ( userDN, userGroup ) ) if userGroup in self.__groupShares: share = self.__groupShares[ userGroup ] else: share = float( DEFAULT_GROUP_SHARE ) if Properties.JOB_SHARING in CS.getPropertiesForGroup( userGroup ): #If group has JobSharing just set prio for that entry, userDN is irrelevant return self.__setPrioritiesForEntity( userDN, userGroup, share, connObj = connObj ) selSQL = "SELECT OwnerDN, COUNT(OwnerDN) FROM `tq_TaskQueues` WHERE OwnerGroup='%s' GROUP BY OwnerDN" % ( userGroup ) result = self._query( selSQL, conn = connObj ) if 
not result[ 'OK' ]: return result #Get owners in this group and the amount of times they appear data = [ ( r[0], r[1] ) for r in result[ 'Value' ] if r ] numOwners = len( data ) #If there are no owners, return now if numOwners == 0: return S_OK() #Split the share amongst the number of owners share /= numOwners entitiesShares = dict( [ ( row[0], share ) for row in data ] ) #If corrector is enabled let it work its magic if self.isSharesCorrectionEnabled(): entitiesShares = self.__sharesCorrector.correctShares( entitiesShares, group = userGroup ) #Keep updating owners = dict( data ) #If the user is already known and has more than 1 tq, the rest of the users don't need to be modified #(The number of owners didn't change) if userDN in owners and owners[ userDN ] > 1: return self.__setPrioritiesForEntity( userDN, userGroup, entitiesShares[ userDN ], connObj = connObj ) #The number of owners may have changed so we recalculate the prio for all owners in the group for userDN in owners: self.__setPrioritiesForEntity( userDN, userGroup, entitiesShares[ userDN ], connObj = connObj ) return S_OK() def __setPrioritiesForEntity( self, userDN, userGroup, share, connObj = False, consolidationFunc = "AVG" ): """ Set the priority for a userDN/userGroup combo given a split share """ self.log.info( "Setting priorities to %s@%s TQs" % ( userDN, userGroup ) ) tqCond = [ "t.OwnerGroup='%s'" % userGroup ] allowBgTQs = gConfig.getValue( "/Registry/Groups/%s/AllowBackgroundTQs" % userGroup, False ) if Properties.JOB_SHARING not in CS.getPropertiesForGroup( userGroup ): tqCond.append( "t.OwnerDN='%s'" % userDN ) tqCond.append( "t.TQId = j.TQId" ) if consolidationFunc == 'AVG': selectSQL = "SELECT j.TQId, SUM( j.RealPriority )/COUNT(j.RealPriority) FROM `tq_TaskQueues` t, `tq_Jobs` j WHERE " elif consolidationFunc == 'SUM': selectSQL = "SELECT j.TQId, SUM( j.RealPriority ) FROM `tq_TaskQueues` t, `tq_Jobs` j WHERE " else: return S_ERROR( "Unknown consolidation func %s for setting priorities" % consolidationFunc ) selectSQL += " AND ".join( tqCond ) selectSQL += " GROUP BY t.TQId" result = self._query( selectSQL, conn = connObj ) if not result[ 'OK' ]: return result tqDict = dict( result[ 'Value' ] ) if len( tqDict ) == 0: return S_OK() #Calculate sum of priorities totalPrio = 0 for k in tqDict: if tqDict[k] > 0.1 or not allowBgTQs: totalPrio += tqDict[ k ] #Group by priorities prioDict = {} for tqId in tqDict: if tqDict[ tqId ] > 0.1 or not allowBgTQs: prio = ( share / totalPrio ) * tqDict[ tqId ] else: prio = TQ_MIN_SHARE prio = max( prio, TQ_MIN_SHARE ) if prio not in prioDict: prioDict[ prio ] = [] prioDict[ prio ].append( tqId ) #Execute updates for prio in prioDict: tqList = ", ".join( [ str( tqId ) for tqId in prioDict[ prio ] ] ) updateSQL = "UPDATE `tq_TaskQueues` SET Priority=%.4f WHERE TQId in ( %s )" % ( prio, tqList ) self._update( updateSQL, conn = connObj ) return S_OK() def getGroupShares( self ): """ Get all the shares as a DICT """ result = gConfig.getSections( "/Registry/Groups" ) if result[ 'OK' ]: groups = result[ 'Value' ] else: groups = [] shares = {} for group in groups: shares[ group ] = gConfig.getValue( "/Registry/Groups/%s/JobShare" % group, DEFAULT_GROUP_SHARE ) return shares def propagateTQSharesIfChanged( self ): """ If the shares have changed in the CS, recalculate priorities """ shares = self.getGroupShares() if shares == self.__groupShares: return S_OK() self.__groupShares = shares return self.recalculateTQSharesForAll() def modifyJobsPriorities( self, jobPrioDict ): """ Modify the 
priority for some jobs """ for jId in jobPrioDict: jobPrioDict[jId] = int( jobPrioDict[jId] ) maxJobsInQuery = 1000 jobsList = sorted( jobPrioDict ) prioDict = {} for jId in jobsList: prio = jobPrioDict[ jId ] if not prio in prioDict: prioDict[ prio ] = [] prioDict[ prio ].append( str( jId ) ) updated = 0 for prio in prioDict: jobsList = prioDict[ prio ] for i in range( maxJobsInQuery, 0, len( jobsList ) ): jobs = ",".join( jobsList[ i : i + maxJobsInQuery ] ) updateSQL = "UPDATE `tq_Jobs` SET `Priority`=%s, `RealPriority`=%f WHERE `JobId` in ( %s )" % ( prio, self.__hackJobPriority( prio ), jobs ) result = self._update( updateSQL ) if not result[ 'OK' ]: return result updated += result[ 'Value' ] if not updated: return S_OK() return self.recalculateTQSharesForAll()
class TaskQueueDB( DB ): def __init__( self, maxQueueSize = 10 ): random.seed() DB.__init__( self, 'TaskQueueDB', 'WorkloadManagement/TaskQueueDB', maxQueueSize ) self.__multiValueDefFields = ( 'Sites', 'GridCEs', 'GridMiddlewares', 'BannedSites', 'Platforms', 'PilotTypes', 'SubmitPools', 'JobTypes' ) self.__multiValueMatchFields = ( 'GridCE', 'Site', 'GridMiddleware', 'Platform', 'PilotType', 'SubmitPool', 'JobType' ) self.__bannedJobMatchFields = ( 'Site', ) self.__strictRequireMatchFields = ( 'SubmitPool', 'Platform', 'PilotType' ) self.__singleValueDefFields = ( 'OwnerDN', 'OwnerGroup', 'Setup', 'CPUTime' ) self.__mandatoryMatchFields = ( 'Setup', 'CPUTime' ) self.__defaultCPUSegments = maxCPUSegments self.__maxMatchRetry = 3 self.__jobPriorityBoundaries = ( 0.001, 10 ) self.__groupShares = {} self.__deleteTQWithDelay = DictCache( self.__deleteTQIfEmpty ) self.__opsHelper = Operations() self.__ensureInsertionIsSingle = False self.__sharesCorrector = SharesCorrector( self.__opsHelper ) result = self.__initializeDB() if not result[ 'OK' ]: raise Exception( "Can't create tables: %s" % result[ 'Message' ] ) def enableAllTaskQueues( self ): """ Enable all Task queues """ return self.updateFields( "tq_TaskQueues", updateDict = { "Enabled" :"1" } ) def findOrphanJobs( self ): """ Find jobs that are not in any task queue """ result = self._query( "select JobID from tq_Jobs WHERE TQId not in (SELECT TQId from tq_TaskQueues)" ) if not result[ 'OK' ]: return result return S_OK( [ row[0] for row in result[ 'Value' ] ] ) def isSharesCorrectionEnabled( self ): return self.__getCSOption( "EnableSharesCorrection", False ) def getSingleValueTQDefFields( self ): return self.__singleValueDefFields def getMultiValueTQDefFields( self ): return self.__multiValueDefFields def getMultiValueMatchFields( self ): return self.__multiValueMatchFields def __getCSOption( self, optionName, defValue ): return self.__opsHelper.getValue( "JobScheduling/%s" % optionName, defValue ) def getPrivatePilots( self ): return self.__getCSOption( "PrivatePilotTypes", [ 'private' ] ) def getValidPilotTypes( self ): return self.__getCSOption( "AllPilotTypes", [ 'private' ] ) def __initializeDB( self ): """ Create the tables """ result = self._query( "show tables" ) if not result[ 'OK' ]: return result tablesInDB = [ t[0] for t in result[ 'Value' ] ] tablesToCreate = {} self.__tablesDesc = {} self.__tablesDesc[ 'tq_TaskQueues' ] = { 'Fields' : { 'TQId' : 'INTEGER UNSIGNED AUTO_INCREMENT NOT NULL', 'OwnerDN' : 'VARCHAR(255) NOT NULL', 'OwnerGroup' : 'VARCHAR(32) NOT NULL', 'Setup' : 'VARCHAR(32) NOT NULL', 'CPUTime' : 'BIGINT UNSIGNED NOT NULL', 'Priority' : 'FLOAT NOT NULL', 'Enabled' : 'TINYINT(1) NOT NULL DEFAULT 0' }, 'PrimaryKey' : 'TQId', 'Indexes': { 'TQOwner': [ 'OwnerDN', 'OwnerGroup', 'Setup', 'CPUTime' ] } } self.__tablesDesc[ 'tq_Jobs' ] = { 'Fields' : { 'TQId' : 'INTEGER UNSIGNED NOT NULL', 'JobId' : 'INTEGER UNSIGNED NOT NULL', 'Priority' : 'INTEGER UNSIGNED NOT NULL', 'RealPriority' : 'FLOAT NOT NULL' }, 'PrimaryKey' : 'JobId', 'Indexes': { 'TaskIndex': [ 'TQId' ] }, } for multiField in self.__multiValueDefFields: tableName = 'tq_TQTo%s' % multiField self.__tablesDesc[ tableName ] = { 'Fields' : { 'TQId' : 'INTEGER UNSIGNED NOT NULL', 'Value' : 'VARCHAR(64) NOT NULL', }, 'Indexes': { 'TaskIndex': [ 'TQId' ], '%sIndex' % multiField: [ 'Value' ] }, } for tableName in self.__tablesDesc: if not tableName in tablesInDB: tablesToCreate[ tableName ] = self.__tablesDesc[ tableName ] return self._createTables( tablesToCreate ) 
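# Added schema note: each multi-value requirement ( Sites, GridCEs, ... ) lives in its own
# pivot table tq_TQTo<Field>, one ( TQId, Value ) row per value; e.g. a queue requiring
# Sites = [ 'LCG.CERN.ch', 'LCG.PIC.es' ] owns exactly two rows in tq_TQToSites.
# __generateTQFindSQL below exploits this to test set equality with two COUNT subqueries.
# Illustrative SQL for that two-site example ( values shown unescaped ):
#   2 = ( SELECT COUNT(`tq_TQToSites`.Value) FROM `tq_TQToSites`
#         WHERE `tq_TQToSites`.TQId = `tq_TaskQueues`.TQId GROUP BY `tq_TQToSites`.TQId )
#   2 = ( SELECT COUNT(`tq_TQToSites`.Value) FROM `tq_TQToSites`
#         WHERE `tq_TQToSites`.TQId = `tq_TaskQueues`.TQId
#         AND `tq_TQToSites`.Value in ( 'LCG.CERN.ch', 'LCG.PIC.es' )
#         GROUP BY `tq_TQToSites`.TQId )
# The first condition pins the stored row count, the second checks that every stored row is
# in the requested set; together they hold only when the two sets are identical.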
def getGroupsInTQs( self ): cmdSQL = "SELECT DISTINCT( OwnerGroup ) FROM `tq_TaskQueues`" result = self._query( cmdSQL ) if not result[ 'OK' ]: return result return S_OK( [ row[0] for row in result[ 'Value' ] ] ) def forceRecreationOfTables( self ): dropSQL = "DROP TABLE IF EXISTS %s" % ", ".join( self.__tablesDesc ) result = self._update( dropSQL ) if not result[ 'OK' ]: return result return self._createTables( self.__tablesDesc ) def __strDict( self, dDict ): lines = [] for key in sorted( dDict ): lines.append( " %s" % key ) value = dDict[ key ] if type( value ) in ( types.ListType, types.TupleType ): lines.extend( [ " %s" % v for v in value ] ) else: lines.append( " %s" % str( value ) ) return "{\n%s\n}" % "\n".join( lines ) def fitCPUTimeToSegments( self, cpuTime ): """ Fit the CPU time to the valid segments """ maxCPUSegments = self.__getCSOption( "taskQueueCPUTimeIntervals", self.__defaultCPUSegments ) try: maxCPUSegments = [ int( seg ) for seg in maxCPUSegments ] #Check segments in the CS last = 0 for cpuS in maxCPUSegments: if cpuS <= last: maxCPUSegments = self.__defaultCPUSegments break last = cpuS except: maxCPUSegments = self.__defaultCPUSegments #Map to a segment for iP in range( len( maxCPUSegments ) ): cpuSegment = maxCPUSegments[ iP ] if cpuTime <= cpuSegment: return cpuSegment return maxCPUSegments[-1] def _checkTaskQueueDefinition( self, tqDefDict ): """ Check a task queue definition dict is valid """ # Confine the LHCbPlatform legacy option here, use Platform everywhere else # until the LHCbPlatform is no more used in the TaskQueueDB if 'LHCbPlatforms' in tqDefDict and not "Platforms" in tqDefDict: tqDefDict['Platforms'] = tqDefDict['LHCbPlatforms'] for field in self.__singleValueDefFields: if field not in tqDefDict: return S_ERROR( "Missing mandatory field '%s' in task queue definition" % field ) fieldValueType = type( tqDefDict[ field ] ) if field in [ "CPUTime" ]: if fieldValueType not in ( types.IntType, types.LongType ): return S_ERROR( "Mandatory field %s value type is not valid: %s" % ( field, fieldValueType ) ) else: if fieldValueType not in ( types.StringType, types.UnicodeType ): return S_ERROR( "Mandatory field %s value type is not valid: %s" % ( field, fieldValueType ) ) result = self._escapeString( tqDefDict[ field ] ) if not result[ 'OK' ]: return result tqDefDict[ field ] = result[ 'Value' ] for field in self.__multiValueDefFields: if field not in tqDefDict: continue fieldValueType = type( tqDefDict[ field ] ) if fieldValueType not in ( types.ListType, types.TupleType ): return S_ERROR( "Multi value field %s value type is not valid: %s" % ( field, fieldValueType ) ) result = self._escapeValues( tqDefDict[ field ] ) if not result[ 'OK' ]: return result tqDefDict[ field ] = result[ 'Value' ] #FIXME: This is not used if 'PrivatePilots' in tqDefDict: validPilotTypes = self.getValidPilotTypes() for pilotType in tqDefDict[ 'PrivatePilots' ]: if pilotType not in validPilotTypes: return S_ERROR( "PilotType %s is invalid" % pilotType ) return S_OK( tqDefDict ) def _checkMatchDefinition( self, tqMatchDict ): """ Check a task queue match dict is valid """ def travelAndCheckType( value, validTypes, escapeValues = True ): valueType = type( value ) if valueType in ( types.ListType, types.TupleType ): for subValue in value: subValueType = type( subValue ) if subValueType not in validTypes: return S_ERROR( "List contained type %s is not valid -> %s" % ( subValueType, validTypes ) ) if escapeValues: return self._escapeValues( value ) return S_OK( value ) else: if valueType 
not in validTypes: return S_ERROR( "Type %s is not valid -> %s" % ( valueType, validTypes ) ) if escapeValues: return self._escapeString( value ) return S_OK( value ) # Confine the LHCbPlatform legacy option here, use Platform everywhere else # until the LHCbPlatform is no more used in the TaskQueueDB if 'LHCbPlatform' in tqMatchDict and not "Platform" in tqMatchDict: tqMatchDict['Platform'] = tqMatchDict['LHCbPlatform'] for field in self.__singleValueDefFields: if field not in tqMatchDict: if field in self.__mandatoryMatchFields: return S_ERROR( "Missing mandatory field '%s' in match request definition" % field ) continue fieldValue = tqMatchDict[ field ] if field in [ "CPUTime" ]: result = travelAndCheckType( fieldValue, ( types.IntType, types.LongType ), escapeValues = False ) else: result = travelAndCheckType( fieldValue, ( types.StringType, types.UnicodeType ) ) if not result[ 'OK' ]: return S_ERROR( "Match definition field %s failed : %s" % ( field, result[ 'Message' ] ) ) tqMatchDict[ field ] = result[ 'Value' ] #Check multivalue for multiField in self.__multiValueMatchFields: for field in ( multiField, "Banned%s" % multiField ): if field in tqMatchDict: fieldValue = tqMatchDict[ field ] result = travelAndCheckType( fieldValue, ( types.StringType, types.UnicodeType ) ) if not result[ 'OK' ]: return S_ERROR( "Match definition field %s failed : %s" % ( field, result[ 'Message' ] ) ) tqMatchDict[ field ] = result[ 'Value' ] return S_OK( tqMatchDict ) def __createTaskQueue( self, tqDefDict, priority = 1, connObj = False ): """ Create a task queue Returns S_OK( tqId ) / S_ERROR """ if not connObj: result = self._getConnection() if not result[ 'OK' ]: return S_ERROR( "Can't create task queue: %s" % result[ 'Message' ] ) connObj = result[ 'Value' ] tqDefDict[ 'CPUTime' ] = self.fitCPUTimeToSegments( tqDefDict[ 'CPUTime' ] ) sqlSingleFields = [ 'TQId', 'Priority' ] sqlValues = [ "0", str( priority ) ] for field in self.__singleValueDefFields: sqlSingleFields.append( field ) sqlValues.append( tqDefDict[ field ] ) #Insert the TQ Disabled sqlSingleFields.append( "Enabled" ) sqlValues.append( "0" ) cmd = "INSERT INTO tq_TaskQueues ( %s ) VALUES ( %s )" % ( ", ".join( sqlSingleFields ), ", ".join( [ str( v ) for v in sqlValues ] ) ) result = self._update( cmd, conn = connObj ) if not result[ 'OK' ]: self.log.error( "Can't insert TQ in DB", result[ 'Value' ] ) return result if 'lastRowId' in result: tqId = result['lastRowId'] else: result = self._query( "SELECT LAST_INSERT_ID()", conn = connObj ) if not result[ 'OK' ]: self.cleanOrphanedTaskQueues( connObj = connObj ) return S_ERROR( "Can't determine task queue id after insertion" ) tqId = result[ 'Value' ][0][0] for field in self.__multiValueDefFields: if field not in tqDefDict: continue values = List.uniqueElements( [ value for value in tqDefDict[ field ] if value.strip() ] ) if not values: continue cmd = "INSERT INTO `tq_TQTo%s` ( TQId, Value ) VALUES " % field cmd += ", ".join( [ "( %s, %s )" % ( tqId, str( value ) ) for value in values ] ) result = self._update( cmd, conn = connObj ) if not result[ 'OK' ]: self.log.error( "Failed to insert %s condition" % field, result[ 'Message' ] ) self.cleanOrphanedTaskQueues( connObj = connObj ) return S_ERROR( "Can't insert values %s for field %s: %s" % ( str( values ), field, result[ 'Message' ] ) ) self.log.info( "Created TQ %s" % tqId ) return S_OK( tqId ) def cleanOrphanedTaskQueues( self, connObj = False ): """ Delete all empty task queues """ self.log.info( "Cleaning orphaned TQs" ) result = 
self._update( "DELETE FROM `tq_TaskQueues` WHERE Enabled >= 1 AND TQId not in ( SELECT DISTINCT TQId from `tq_Jobs` )", conn = connObj ) if not result[ 'OK' ]: return result for mvField in self.__multiValueDefFields: result = self._update( "DELETE FROM `tq_TQTo%s` WHERE TQId not in ( SELECT DISTINCT TQId from `tq_TaskQueues` )" % mvField, conn = connObj ) if not result[ 'OK' ]: return result return S_OK() def setTaskQueueState( self, tqId, enabled = True, connObj = False ): if enabled: enabled = "+ 1" else: enabled = "- 1" upSQL = "UPDATE `tq_TaskQueues` SET Enabled = Enabled %s WHERE TQId=%d" % ( enabled, tqId ) result = self._update( upSQL, conn = connObj ) if not result[ 'OK' ]: self.log.error( "Error setting TQ state", "TQ %s State %s: %s" % ( tqId, enabled, result[ 'Message' ] ) ) return result updated = result['Value'] > 0 if updated: self.log.info( "Set enabled = %s for TQ %s" % ( enabled, tqId ) ) return S_OK( updated ) def __hackJobPriority( self, jobPriority ): jobPriority = min( max( int( jobPriority ), self.__jobPriorityBoundaries[0] ), self.__jobPriorityBoundaries[1] ) if jobPriority == self.__jobPriorityBoundaries[0]: return 10 ** ( -5 ) if jobPriority == self.__jobPriorityBoundaries[1]: return 10 ** 6 return jobPriority def insertJob( self, jobId, tqDefDict, jobPriority, skipTQDefCheck = False, numRetries = 10 ): """ Insert a job in a task queue Returns S_OK( tqId ) / S_ERROR """ try: test = long( jobId ) except: return S_ERROR( "JobId is not a number!" ) retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't insert job: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] if not skipTQDefCheck: tqDefDict = dict( tqDefDict ) retVal = self._checkTaskQueueDefinition( tqDefDict ) if not retVal[ 'OK' ]: self.log.error( "TQ definition check failed", retVal[ 'Message' ] ) return retVal tqDefDict = retVal[ 'Value' ] tqDefDict[ 'CPUTime' ] = self.fitCPUTimeToSegments( tqDefDict[ 'CPUTime' ] ) self.log.info( "Inserting job %s with requirements: %s" % ( jobId, self.__strDict( tqDefDict ) ) ) retVal = self.__findAndDisableTaskQueue( tqDefDict, skipDefinitionCheck = True, connObj = connObj ) if not retVal[ 'OK' ]: return retVal tqInfo = retVal[ 'Value' ] newTQ = False if not tqInfo[ 'found' ]: self.log.info( "Creating a TQ for job %s" % jobId ) retVal = self.__createTaskQueue( tqDefDict, 1, connObj = connObj ) if not retVal[ 'OK' ]: return retVal tqId = retVal[ 'Value' ] newTQ = True else: tqId = tqInfo[ 'tqId' ] self.log.info( "Found TQ %s for job %s requirements" % ( tqId, jobId ) ) try: result = self.__insertJobInTaskQueue( jobId, tqId, int( jobPriority ), checkTQExists = False, connObj = connObj ) if not result[ 'OK' ]: self.log.error( "Error inserting job in TQ", "Job %s TQ %s: %s" % ( jobId, tqId, result[ 'Message' ] ) ) return result if newTQ: self.recalculateTQSharesForEntity( tqDefDict[ 'OwnerDN' ], tqDefDict[ 'OwnerGroup' ], connObj = connObj ) finally: self.setTaskQueueState( tqId, True ) return S_OK() def __insertJobInTaskQueue( self, jobId, tqId, jobPriority, checkTQExists = True, connObj = False ): """ Insert a job in a given task queue """ self.log.info( "Inserting job %s in TQ %s with priority %s" % ( jobId, tqId, jobPriority ) ) if not connObj: result = self._getConnection() if not result[ 'OK' ]: return S_ERROR( "Can't insert job: %s" % result[ 'Message' ] ) connObj = result[ 'Value' ] if checkTQExists: result = self._query( "SELECT tqId FROM `tq_TaskQueues` WHERE TQId = %s" % tqId, conn = connObj ) if not result[ 'OK' ] or len ( result[ 
'Value' ] ) == 0: return S_ERROR( "Can't find task queue with id %s: %s" % ( tqId, result[ 'Message' ] ) ) hackedPriority = self.__hackJobPriority( jobPriority ) result = self._update( "INSERT INTO tq_Jobs ( TQId, JobId, Priority, RealPriority ) VALUES ( %s, %s, %s, %f )" % ( tqId, jobId, jobPriority, hackedPriority ), conn = connObj ) if not result[ 'OK' ] and result[ 'Message' ].find( "Duplicate entry" ) == -1: return result return S_OK() def __generateTQFindSQL( self, tqDefDict, skipDefinitionCheck = False, connObj = False ): """ Find a task queue that has exactly the same requirements """ if not skipDefinitionCheck: tqDefDict = dict( tqDefDict ) result = self._checkTaskQueueDefinition( tqDefDict ) if not result[ 'OK' ]: return result tqDefDict = result[ 'Value' ] sqlCondList = [] for field in self.__singleValueDefFields: sqlCondList.append( "`tq_TaskQueues`.%s = %s" % ( field, tqDefDict[ field ] ) ) #MAGIC SUBQUERIES TO ENSURE STRICT MATCH for field in self.__multiValueDefFields: tableName = '`tq_TQTo%s`' % field if field in tqDefDict and tqDefDict[ field ]: firstQuery = "SELECT COUNT(%s.Value) FROM %s WHERE %s.TQId = `tq_TaskQueues`.TQId" % ( tableName, tableName, tableName ) grouping = "GROUP BY %s.TQId" % tableName valuesList = List.uniqueElements( [ value.strip() for value in tqDefDict[ field ] if value.strip() ] ) numValues = len( valuesList ) secondQuery = "%s AND %s.Value in (%s)" % ( firstQuery, tableName, ",".join( [ "%s" % str( value ) for value in valuesList ] ) ) sqlCondList.append( "%s = (%s %s)" % ( numValues, firstQuery, grouping ) ) sqlCondList.append( "%s = (%s %s)" % ( numValues, secondQuery, grouping ) ) else: sqlCondList.append( "`tq_TaskQueues`.TQId not in ( SELECT DISTINCT %s.TQId from %s )" % ( tableName, tableName ) ) #END MAGIC: That was easy ;) return S_OK( " AND ".join( sqlCondList ) ) def __findAndDisableTaskQueue( self, tqDefDict, skipDefinitionCheck = False, retries = 10, connObj = False ): """ Disable and find TQ """ for i in range( retries ): result = self.findTaskQueue( tqDefDict, skipDefinitionCheck = skipDefinitionCheck, connObj = connObj ) if not result[ 'OK' ]: return result data = result[ 'Value' ] if not data[ 'found' ]: return result result = self._update( "UPDATE `tq_TaskQueues` SET Enabled = Enabled - 1 WHERE TQId = %d" % data[ 'tqId' ] ) if not result[ 'OK' ]: return result if result[ 'Value' ] > 0: return S_OK( data ) return S_ERROR( "Could not disable TQ" ) def findTaskQueue( self, tqDefDict, skipDefinitionCheck = False, connObj = False ): """ Find a task queue that has exactly the same requirements """ result = self.__generateTQFindSQL( tqDefDict, skipDefinitionCheck = skipDefinitionCheck, connObj = connObj ) if not result[ 'OK' ]: return result sqlCmd = "SELECT `tq_TaskQueues`.TQId FROM `tq_TaskQueues` WHERE" sqlCmd = "%s %s" % ( sqlCmd, result[ 'Value' ] ) result = self._query( sqlCmd, conn = connObj ) if not result[ 'OK' ]: return S_ERROR( "Can't find task queue: %s" % result[ 'Message' ] ) data = result[ 'Value' ] if len( data ) == 0: return S_OK( { 'found' : False } ) if len( data ) > 1: gLogger.warn( "Found two task queues for the same requirements", self.__strDict( tqDefDict ) ) return S_OK( { 'found' : True, 'tqId' : data[0][0] } ) def matchAndGetJob( self, tqMatchDict, numJobsPerTry = 50, numQueuesPerTry = 10, negativeCond = {} ): """ Match a job """ #Make a copy to avoid modification of original if escaping needs to be done tqMatchDict = dict( tqMatchDict ) self.log.info( "Starting match for requirements", self.__strDict( tqMatchDict 
) ) retVal = self._checkMatchDefinition( tqMatchDict ) if not retVal[ 'OK' ]: self.log.error( "TQ match request check failed", retVal[ 'Message' ] ) return retVal retVal = self._getConnection() if not retVal[ 'OK' ]: return S_ERROR( "Can't connect to DB: %s" % retVal[ 'Message' ] ) connObj = retVal[ 'Value' ] preJobSQL = "SELECT `tq_Jobs`.JobId, `tq_Jobs`.TQId FROM `tq_Jobs` WHERE `tq_Jobs`.TQId = %s AND `tq_Jobs`.Priority = %s" prioSQL = "SELECT `tq_Jobs`.Priority FROM `tq_Jobs` WHERE `tq_Jobs`.TQId = %s ORDER BY RAND() / `tq_Jobs`.RealPriority ASC LIMIT 1" postJobSQL = " ORDER BY `tq_Jobs`.JobId ASC LIMIT %s" % numJobsPerTry for matchTry in range( self.__maxMatchRetry ): if 'JobID' in tqMatchDict: # A certain JobID is required by the resource, so all TQ are to be considered retVal = self.matchAndGetTaskQueue( tqMatchDict, numQueuesToGet = 0, skipMatchDictDef = True, connObj = connObj ) preJobSQL = "%s AND `tq_Jobs`.JobId = %s " % ( preJobSQL, tqMatchDict['JobID'] ) else: retVal = self.matchAndGetTaskQueue( tqMatchDict, numQueuesToGet = numQueuesPerTry, skipMatchDictDef = True, negativeCond = negativeCond, connObj = connObj ) if not retVal[ 'OK' ]: return retVal tqList = retVal[ 'Value' ] if len( tqList ) == 0: self.log.info( "No TQ matches requirements" ) return S_OK( { 'matchFound' : False, 'tqMatch' : tqMatchDict } ) for tqId, tqOwnerDN, tqOwnerGroup in tqList: self.log.info( "Trying to extract jobs from TQ %s" % tqId ) retVal = self._query( prioSQL % tqId, conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Can't retrieve winning priority for matching job: %s" % retVal[ 'Message' ] ) if len( retVal[ 'Value' ] ) == 0: continue prio = retVal[ 'Value' ][0][0] retVal = self._query( "%s %s" % ( preJobSQL % ( tqId, prio ), postJobSQL ), conn = connObj ) if not retVal[ 'OK' ]: return S_ERROR( "Can't begin transaction for matching job: %s" % retVal[ 'Message' ] ) jobTQList = [ ( row[0], row[1] ) for row in retVal[ 'Value' ] ] if len( jobTQList ) == 0: gLogger.info( "Task queue %s seems to be empty, triggering a cleaning" % tqId ) self.__deleteTQWithDelay.add( tqId, 300, ( tqId, tqOwnerDN, tqOwnerGroup ) ) while len( jobTQList ) > 0: jobId, tqId = jobTQList.pop( random.randint( 0, len( jobTQList ) - 1 ) ) self.log.info( "Trying to extract job %s from TQ %s" % ( jobId, tqId ) ) retVal = self.deleteJob( jobId, connObj = connObj ) if not retVal[ 'OK' ]: msgFix = "Could not take job" msgVar = " %s out from the TQ %s: %s" % ( jobId, tqId, retVal[ 'Message' ] ) self.log.error( msgFix, msgVar ) return S_ERROR( msgFix + msgVar ) if retVal[ 'Value' ] == True : self.log.info( "Extracted job %s with prio %s from TQ %s" % ( jobId, prio, tqId ) ) return S_OK( { 'matchFound' : True, 'jobId' : jobId, 'taskQueueId' : tqId, 'tqMatch' : tqMatchDict } ) self.log.info( "No jobs could be extracted from TQ %s" % tqId ) self.log.info( "Could not find a match after %s match retries" % self.__maxMatchRetry ) return S_ERROR( "Could not find a match after %s match retries" % self.__maxMatchRetry ) def matchAndGetTaskQueue( self, tqMatchDict, numQueuesToGet = 1, skipMatchDictDef = False, negativeCond = {}, connObj = False ): """ Get a queue that matches the requirements """ #Make a copy to avoid modification of original if escaping needs to be done tqMatchDict = dict( tqMatchDict ) if not skipMatchDictDef: retVal = self._checkMatchDefinition( tqMatchDict ) if not retVal[ 'OK' ]: return retVal retVal = self.__generateTQMatchSQL( tqMatchDict, numQueuesToGet = numQueuesToGet, negativeCond = negativeCond ) if not 
  def matchAndGetTaskQueue( self, tqMatchDict, numQueuesToGet = 1, skipMatchDictDef = False, negativeCond = {}, connObj = False ):
    """ Get a queue that matches the requirements """
    # Make a copy to avoid modifying the original if escaping needs to be done
    tqMatchDict = dict( tqMatchDict )
    if not skipMatchDictDef:
      retVal = self._checkMatchDefinition( tqMatchDict )
      if not retVal[ 'OK' ]:
        return retVal
    retVal = self.__generateTQMatchSQL( tqMatchDict, numQueuesToGet = numQueuesToGet, negativeCond = negativeCond )
    if not retVal[ 'OK' ]:
      return retVal
    matchSQL = retVal[ 'Value' ]
    retVal = self._query( matchSQL, conn = connObj )
    if not retVal[ 'OK' ]:
      return retVal
    return S_OK( [ ( row[0], row[1], row[2] ) for row in retVal[ 'Value' ] ] )

  def __generateSQLSubCond( self, sqlString, value, boolOp = 'OR' ):
    if type( value ) not in ( types.ListType, types.TupleType ):
      return sqlString % str( value ).strip()
    sqlORList = []
    for v in value:
      sqlORList.append( sqlString % str( v ).strip() )
    return "( %s )" % ( " %s " % boolOp ).join( sqlORList )

  def __generateNotSQL( self, tableDict, negativeCond ):
    """ Generate negative conditions.
        negativeCond can be a list of dicts or a single dict:
         - a list of dicts is the OR of the conditions generated for each dict
         - a dict is a standard condition dict: ( key1 in ( v1, v2, ... ) AND key2 in ( v3, v4, ... ) )
    """
    condType = type( negativeCond )
    if condType in ( types.ListType, types.TupleType ):
      sqlCond = []
      for cD in negativeCond:
        sqlCond.append( self.__generateNotDictSQL( tableDict, cD ) )
      return " ( %s )" % " OR ".join( sqlCond )
    elif condType == types.DictType:
      return self.__generateNotDictSQL( tableDict, negativeCond )
    raise RuntimeError( "negativeCond has to be either a list or a dict and it's %s" % condType )

  def __generateNotDictSQL( self, tableDict, negativeCond ):
    """ Generate the negative SQL condition from a standard condition dict """
    condList = []
    for field in negativeCond:
      valList = negativeCond[ field ]
      if type( valList ) not in ( types.TupleType, types.ListType ):
        valList = ( valList, )
      if field in self.__multiValueMatchFields:
        fullTableN = '`tq_TQTo%ss`' % field
        for value in valList:
          value = self._escapeString( value )[ 'Value' ]
          sql = "%s NOT IN ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( value, fullTableN, fullTableN, fullTableN )
          condList.append( sql )
      elif field in self.__singleValueDefFields:
        for value in valList:
          value = self._escapeString( value )[ 'Value' ]
          sql = "%s != tq.%s " % ( value, field )
          condList.append( sql )
    return "( %s )" % " AND ".join( condList )

  def __generateTablesName( self, sqlTables, field ):
    fullTableName = 'tq_TQTo%ss' % field
    if fullTableName not in sqlTables:
      tableN = field.lower()
      sqlTables[ fullTableName ] = tableN
      return tableN, "`%s`" % fullTableName
    return sqlTables[ fullTableName ], "`%s`" % fullTableName
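# Illustrative sketch of the two negativeCond shapes __generateNotSQL accepts, per its
# docstring above. The field values are invented examples, not from any real configuration:
negativeCondAsDict = { 'Site' : [ 'LCG.CERN.ch' ], 'OwnerGroup' : 'lhcb_user' }  # AND of the keys
negativeCondAsList = [ { 'Site' : 'LCG.CERN.ch' },                               # OR of the dicts
                       { 'OwnerGroup' : [ 'lhcb_user', 'lhcb_prod' ] } ]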
  def __generateTQMatchSQL( self, tqMatchDict, numQueuesToGet = 1, negativeCond = {} ):
    """ Generate the SQL needed to match a task queue """
    # Only enabled TQs
    # sqlCondList = [ "Enabled >= 1" ]
    sqlCondList = []
    sqlTables = { "tq_TaskQueues" : "tq" }
    # If OwnerDN and OwnerGroup are both defined, only use the combinations that make sense
    if 'OwnerDN' in tqMatchDict and 'OwnerGroup' in tqMatchDict:
      groups = tqMatchDict[ 'OwnerGroup' ]
      if type( groups ) not in ( types.ListType, types.TupleType ):
        groups = [ groups ]
      dns = tqMatchDict[ 'OwnerDN' ]
      if type( dns ) not in ( types.ListType, types.TupleType ):
        dns = [ dns ]
      ownerConds = []
      for group in groups:
        if Properties.JOB_SHARING in CS.getPropertiesForGroup( group.replace( '"', "" ) ):
          ownerConds.append( "tq.OwnerGroup = %s" % group )
        else:
          for dn in dns:
            ownerConds.append( "( tq.OwnerDN = %s AND tq.OwnerGroup = %s )" % ( dn, group ) )
      sqlCondList.append( " OR ".join( ownerConds ) )
    else:
      # If they are not both defined, just add the ones that are defined
      for field in ( 'OwnerGroup', 'OwnerDN' ):
        if field in tqMatchDict:
          sqlCondList.append( self.__generateSQLSubCond( "tq.%s = %%s" % field, tqMatchDict[ field ] ) )
    # Single value conditions
    for field in ( 'CPUTime', 'Setup' ):
      if field in tqMatchDict:
        if field == 'CPUTime':
          sqlCondList.append( self.__generateSQLSubCond( "tq.%s <= %%s" % field, tqMatchDict[ field ] ) )
        else:
          sqlCondList.append( self.__generateSQLSubCond( "tq.%s = %%s" % field, tqMatchDict[ field ] ) )
    # Match multi-value fields
    for field in self.__multiValueMatchFields:
      # It has to be %ss, with an 's' at the end, because the column names
      # are plural while the match options are singular
      if field in tqMatchDict and tqMatchDict[ field ]:
        tableN, fullTableN = self.__generateTablesName( sqlTables, field )
        sqlMultiCondList = []
        if field != 'GridCE' or 'Site' in tqMatchDict:
          # Jobs for masked sites can be matched if they specify a GridCE.
          # Site is removed from tqMatchDict when the site is masked; in that case the GridCE
          # has to match explicitly, so the COUNT cannot be 0 and this condition is skipped.
          sqlMultiCondList.append( "( SELECT COUNT(%s.Value) FROM %s WHERE %s.TQId = tq.TQId ) = 0" % ( fullTableN, fullTableN, fullTableN ) )
        csql = self.__generateSQLSubCond( "%%s IN ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( fullTableN, fullTableN, fullTableN ),
                                          tqMatchDict[ field ] )
        sqlMultiCondList.append( csql )
        sqlCondList.append( "( %s )" % " OR ".join( sqlMultiCondList ) )
        # In the case of Site, check it is not in the job's banned sites
        if field in self.__bannedJobMatchFields:
          fullTableN = '`tq_TQToBanned%ss`' % field
          csql = self.__generateSQLSubCond( "%%s not in ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( fullTableN, fullTableN, fullTableN ),
                                            tqMatchDict[ field ], boolOp = 'AND' )
          sqlCondList.append( csql )
      # Resource banning
      bannedField = "Banned%s" % field
      if bannedField in tqMatchDict and tqMatchDict[ bannedField ]:
        fullTableN = '`tq_TQTo%ss`' % field
        csql = self.__generateSQLSubCond( "%%s not in ( SELECT %s.Value FROM %s WHERE %s.TQId = tq.TQId )" % ( fullTableN, fullTableN, fullTableN ),
                                          tqMatchDict[ bannedField ], boolOp = 'AND' )
        sqlCondList.append( csql )
    # For certain fields the requirement is strict: if the field is not in tqMatchDict,
    # no TQ requiring it can match
    for field in self.__strictRequireMatchFields:
      if field in tqMatchDict:
        continue
      fullTableN = '`tq_TQTo%ss`' % field
      sqlCondList.append( "( SELECT COUNT(%s.Value) FROM %s WHERE %s.TQId = tq.TQId ) = 0" % ( fullTableN, fullTableN, fullTableN ) )
    # Add extra conditions
    if negativeCond:
      sqlCondList.append( self.__generateNotSQL( sqlTables, negativeCond ) )
    # Generate the final query string
    tqSqlCmd = "SELECT tq.TQId, tq.OwnerDN, tq.OwnerGroup FROM `tq_TaskQueues` tq WHERE %s" % ( " AND ".join( sqlCondList ) )
    # Apply priorities
    tqSqlCmd = "%s ORDER BY RAND() / tq.Priority ASC" % tqSqlCmd
    # Do we want a limit?
    if numQueuesToGet:
      tqSqlCmd = "%s LIMIT %s" % ( tqSqlCmd, numQueuesToGet )
    return S_OK( tqSqlCmd )
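# Hedged usage sketch for the matcher above. The field values are invented examples; the
# mandatory fields ( Setup, CPUTime ) come from __mandatoryMatchFields in the constructor:
exampleMatchDict = { 'Setup' : 'LHCb-Production',
                     'CPUTime' : 86400,                 # resource offers up to this many seconds
                     'Site' : 'LCG.CERN.ch',
                     'BannedSite' : [ 'LCG.Foo.xx' ] }  # resource-side ban, per the Banned%s fields
# result = tqDB.matchAndGetJob( exampleMatchDict )
# if result[ 'OK' ] and result[ 'Value' ][ 'matchFound' ]:
#   jobId = result[ 'Value' ][ 'jobId' ]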
  def deleteJob( self, jobId, connObj = False ):
    """ Delete a job from the task queues.
        Return S_OK( True/False ) / S_ERROR
    """
    if not connObj:
      retVal = self._getConnection()
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't delete job: %s" % retVal[ 'Message' ] )
      connObj = retVal[ 'Value' ]
    retVal = self._query( "SELECT t.TQId, t.OwnerDN, t.OwnerGroup FROM `tq_TaskQueues` t, `tq_Jobs` j WHERE j.JobId = %s AND t.TQId = j.TQId" % jobId,
                          conn = connObj )
    if not retVal[ 'OK' ]:
      return S_ERROR( "Could not get task queue info for job %s: %s" % ( jobId, retVal[ 'Message' ] ) )
    data = retVal[ 'Value' ]
    if not data:
      return S_OK( False )
    tqId, tqOwnerDN, tqOwnerGroup = data[0]
    self.log.info( "Deleting job %s" % jobId )
    retVal = self._update( "DELETE FROM `tq_Jobs` WHERE JobId = %s" % jobId, conn = connObj )
    if not retVal[ 'OK' ]:
      return S_ERROR( "Could not delete job %s from task queue: %s" % ( jobId, retVal[ 'Message' ] ) )
    if retVal[ 'Value' ] == 0:
      # No job deleted
      return S_OK( False )
    # Always return S_OK() from here on: the job has already been taken out of the TQ;
    # the possibly empty TQ is cleaned up later with a delay
    self.__deleteTQWithDelay.add( tqId, 300, ( tqId, tqOwnerDN, tqOwnerGroup ) )
    return S_OK( True )

  def getTaskQueueForJob( self, jobId, connObj = False ):
    """ Return the TaskQueue for a given job.
        Return S_OK( TaskQueueID ) / S_ERROR
    """
    if not connObj:
      retVal = self._getConnection()
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't get TQ for job: %s" % retVal[ 'Message' ] )
      connObj = retVal[ 'Value' ]
    retVal = self._query( 'SELECT TQId FROM `tq_Jobs` WHERE JobId = %s ' % jobId, conn = connObj )
    if not retVal[ 'OK' ]:
      return retVal
    if not retVal[ 'Value' ]:
      return S_ERROR( 'Not in TaskQueues' )
    return S_OK( retVal[ 'Value' ][0][0] )

  def getTaskQueueForJobs( self, jobIDs, connObj = False ):
    """ Return the TaskQueues for a given list of jobs """
    if not connObj:
      retVal = self._getConnection()
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't get TQs for a job list: %s" % retVal[ 'Message' ] )
      connObj = retVal[ 'Value' ]
    jobString = ','.join( [ str( x ) for x in jobIDs ] )
    retVal = self._query( 'SELECT JobId, TQId FROM `tq_Jobs` WHERE JobId in (%s) ' % jobString, conn = connObj )
    if not retVal[ 'OK' ]:
      return retVal
    if not retVal[ 'Value' ]:
      return S_ERROR( 'Not in TaskQueues' )
    resultDict = {}
    for jobID, TQID in retVal[ 'Value' ]:
      resultDict[ int( jobID ) ] = int( TQID )
    return S_OK( resultDict )

  def __getOwnerForTaskQueue( self, tqId, connObj = False ):
    retVal = self._query( "SELECT OwnerDN, OwnerGroup from `tq_TaskQueues` WHERE TQId=%s" % tqId, conn = connObj )
    if not retVal[ 'OK' ]:
      return retVal
    data = retVal[ 'Value' ]
    if len( data ) == 0:
      return S_OK( False )
    return S_OK( data[0] )

  def __deleteTQIfEmpty( self, args ):
    ( tqId, tqOwnerDN, tqOwnerGroup ) = args
    retries = 3
    while retries:
      retries -= 1
      result = self.deleteTaskQueueIfEmpty( tqId, tqOwnerDN, tqOwnerGroup )
      if result[ 'OK' ]:
        return
    gLogger.error( "Could not delete TQ %s: %s" % ( tqId, result[ 'Message' ] ) )
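# Sketch of the delayed-cleanup pattern used with __deleteTQWithDelay: a DictCache built
# with a delete callback runs the callback when an entry expires, so adding a TQ id with a
# 300 s lifetime schedules deleteTaskQueueIfEmpty without blocking the matching path.
# The calls below mirror the code above ( constructor and deleteJob/matchAndGetJob ):
#
#   self.__deleteTQWithDelay = DictCache( self.__deleteTQIfEmpty )     # callback on expiry
#   self.__deleteTQWithDelay.add( tqId, 300, ( tqId, tqOwnerDN, tqOwnerGroup ) )
#
# Re-adding the same tqId presumably refreshes the lifetime, so a TQ that keeps receiving
# traffic keeps having its deletion postponed.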
  def deleteTaskQueueIfEmpty( self, tqId, tqOwnerDN = False, tqOwnerGroup = False, connObj = False ):
    """ Try to delete a task queue if it is empty """
    if not connObj:
      retVal = self._getConnection()
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't delete task queue: %s" % retVal[ 'Message' ] )
      connObj = retVal[ 'Value' ]
    if not tqOwnerDN or not tqOwnerGroup:
      retVal = self.__getOwnerForTaskQueue( tqId, connObj = connObj )
      if not retVal[ 'OK' ]:
        return retVal
      data = retVal[ 'Value' ]
      if not data:
        return S_OK( False )
      tqOwnerDN, tqOwnerGroup = data
    sqlCmd = "DELETE FROM `tq_TaskQueues` WHERE Enabled >= 1 AND `tq_TaskQueues`.TQId = %s" % tqId
    sqlCmd = "%s AND `tq_TaskQueues`.TQId not in ( SELECT DISTINCT TQId from `tq_Jobs` )" % sqlCmd
    retVal = self._update( sqlCmd, conn = connObj )
    if not retVal[ 'OK' ]:
      return S_ERROR( "Could not delete task queue %s: %s" % ( tqId, retVal[ 'Message' ] ) )
    delTQ = retVal[ 'Value' ]
    if delTQ > 0:
      for mvField in self.__multiValueDefFields:
        retVal = self._update( "DELETE FROM `tq_TQTo%s` WHERE TQId = %s" % ( mvField, tqId ), conn = connObj )
        if not retVal[ 'OK' ]:
          return retVal
      self.recalculateTQSharesForEntity( tqOwnerDN, tqOwnerGroup, connObj = connObj )
      self.log.info( "Deleted empty and enabled TQ %s" % tqId )
      return S_OK( True )
    return S_OK( False )

  def deleteTaskQueue( self, tqId, tqOwnerDN = False, tqOwnerGroup = False, connObj = False ):
    """ Delete a task queue even if it has jobs """
    self.log.info( "Deleting TQ %s" % tqId )
    if not connObj:
      retVal = self._getConnection()
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't delete task queue: %s" % retVal[ 'Message' ] )
      connObj = retVal[ 'Value' ]
    if not tqOwnerDN or not tqOwnerGroup:
      retVal = self.__getOwnerForTaskQueue( tqId, connObj = connObj )
      if not retVal[ 'OK' ]:
        return retVal
      data = retVal[ 'Value' ]
      if not data:
        return S_OK( False )
      tqOwnerDN, tqOwnerGroup = data
    sqlCmd = "DELETE FROM `tq_TaskQueues` WHERE `tq_TaskQueues`.TQId = %s" % tqId
    retVal = self._update( sqlCmd, conn = connObj )
    if not retVal[ 'OK' ]:
      return S_ERROR( "Could not delete task queue %s: %s" % ( tqId, retVal[ 'Message' ] ) )
    delTQ = retVal[ 'Value' ]
    sqlCmd = "DELETE FROM `tq_Jobs` WHERE `tq_Jobs`.TQId = %s" % tqId
    retVal = self._update( sqlCmd, conn = connObj )
    if not retVal[ 'OK' ]:
      return S_ERROR( "Could not delete jobs of task queue %s: %s" % ( tqId, retVal[ 'Message' ] ) )
    for mvField in self.__multiValueDefFields:
      retVal = self._update( "DELETE FROM `tq_TQTo%s` WHERE TQId = %s" % ( mvField, tqId ), conn = connObj )
      if not retVal[ 'OK' ]:
        return retVal
    if delTQ > 0:
      self.recalculateTQSharesForEntity( tqOwnerDN, tqOwnerGroup, connObj = connObj )
      return S_OK( True )
    return S_OK( False )

  def getMatchingTaskQueues( self, tqMatchDict, negativeCond = False ):
    """ Same method name as exposed in the Matcher """
    return self.retrieveTaskQueuesThatMatch( tqMatchDict, negativeCond = negativeCond )

  def getNumTaskQueues( self ):
    """ Get the number of task queues in the system """
    sqlCmd = "SELECT COUNT( TQId ) FROM `tq_TaskQueues`"
    retVal = self._query( sqlCmd )
    if not retVal[ 'OK' ]:
      return retVal
    return S_OK( retVal[ 'Value' ][0][0] )

  def retrieveTaskQueuesThatMatch( self, tqMatchDict, negativeCond = False ):
    """ Get the info of the task queues that match a resource """
    result = self.matchAndGetTaskQueue( tqMatchDict, numQueuesToGet = 0, negativeCond = negativeCond )
    if not result[ 'OK' ]:
      return result
    return self.retrieveTaskQueues( [ tqTuple[0] for tqTuple in result[ 'Value' ] ] )
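# Minimal sketch of the guard used in deleteTaskQueueIfEmpty: the DELETE only touches rows
# that are both enabled and not referenced from tq_Jobs, so emptiness is checked and the
# row removed in a single statement instead of a racy SELECT-then-DELETE:
sql = ( "DELETE FROM `tq_TaskQueues` "
        "WHERE Enabled >= 1 AND TQId = %s "
        "AND TQId not in ( SELECT DISTINCT TQId from `tq_Jobs` )" )
# The affected-row count ( retVal[ 'Value' ] ) then tells whether the TQ was really deleted.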
  def retrieveTaskQueues( self, tqIdList = False ):
    """ Get all the task queues """
    sqlSelectEntries = [ "`tq_TaskQueues`.TQId", "`tq_TaskQueues`.Priority", "COUNT( `tq_Jobs`.TQId )" ]
    sqlGroupEntries = [ "`tq_TaskQueues`.TQId", "`tq_TaskQueues`.Priority" ]
    for field in self.__singleValueDefFields:
      sqlSelectEntries.append( "`tq_TaskQueues`.%s" % field )
      sqlGroupEntries.append( "`tq_TaskQueues`.%s" % field )
    sqlCmd = "SELECT %s FROM `tq_TaskQueues`, `tq_Jobs`" % ", ".join( sqlSelectEntries )
    sqlTQCond = "AND Enabled >= 1"
    if tqIdList != False:
      if len( tqIdList ) == 0:
        return S_OK( {} )
      else:
        sqlTQCond += " AND `tq_TaskQueues`.TQId in ( %s )" % ", ".join( [ str( tqId ) for tqId in tqIdList ] )
    sqlCmd = "%s WHERE `tq_TaskQueues`.TQId = `tq_Jobs`.TQId %s GROUP BY %s" % ( sqlCmd, sqlTQCond, ", ".join( sqlGroupEntries ) )
    retVal = self._query( sqlCmd )
    if not retVal[ 'OK' ]:
      return S_ERROR( "Can't retrieve task queues info: %s" % retVal[ 'Message' ] )
    tqData = {}
    for record in retVal[ 'Value' ]:
      tqId = record[0]
      tqData[ tqId ] = { 'Priority' : record[1], 'Jobs' : record[2] }
      record = record[3:]
      for iP in range( len( self.__singleValueDefFields ) ):
        tqData[ tqId ][ self.__singleValueDefFields[ iP ] ] = record[ iP ]
    tqNeedCleaning = False
    for field in self.__multiValueDefFields:
      table = "`tq_TQTo%s`" % field
      sqlCmd = "SELECT %s.TQId, %s.Value FROM %s" % ( table, table, table )
      retVal = self._query( sqlCmd )
      if not retVal[ 'OK' ]:
        return S_ERROR( "Can't retrieve task queues field %s info: %s" % ( field, retVal[ 'Message' ] ) )
      for record in retVal[ 'Value' ]:
        tqId = record[0]
        value = record[1]
        if tqId not in tqData:
          if tqIdList == False or tqId in tqIdList:
            self.log.warn( "Task Queue %s is defined in field %s but does not exist, triggering a cleaning" % ( tqId, field ) )
            tqNeedCleaning = True
        else:
          if field not in tqData[ tqId ]:
            tqData[ tqId ][ field ] = []
          tqData[ tqId ][ field ].append( value )
    if tqNeedCleaning:
      self.cleanOrphanedTaskQueues()
    return S_OK( tqData )

  def __updateGlobalShares( self ):
    """ Update the internal structure for shares """
    # Update group shares
    self.__groupShares = self.getGroupShares()
    # Apply corrections if enabled
    if self.isSharesCorrectionEnabled():
      result = self.getGroupsInTQs()
      if not result[ 'OK' ]:
        self.log.error( "Could not get groups in the TQs", result[ 'Message' ] )
        return
      activeGroups = result[ 'Value' ]
      newShares = {}
      for group in activeGroups:
        if group in self.__groupShares:
          newShares[ group ] = self.__groupShares[ group ]
      newShares = self.__sharesCorrector.correctShares( newShares )
      for group in self.__groupShares:
        if group in newShares:
          self.__groupShares[ group ] = newShares[ group ]

  def recalculateTQSharesForAll( self ):
    """ Recalculate all priorities for the TQs """
    if self.isSharesCorrectionEnabled():
      self.log.info( "Updating correctors state" )
      self.__sharesCorrector.update()
    self.__updateGlobalShares()
    self.log.info( "Recalculating shares for all TQs" )
    retVal = self._getConnection()
    if not retVal[ 'OK' ]:
      return S_ERROR( "Can't recalculate shares: %s" % retVal[ 'Message' ] )
    connObj = retVal[ 'Value' ]
    result = self._query( "SELECT DISTINCT( OwnerGroup ) FROM `tq_TaskQueues`" )
    if not result[ 'OK' ]:
      return result
    for group in [ r[0] for r in result[ 'Value' ] ]:
      self.recalculateTQSharesForEntity( "all", group )
    return S_OK()
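# Hedged sketch of the share flow implemented above, with invented numbers: a group's
# JobShare from the CS is split evenly among the owners that currently have TQs in that
# group, and each owner's slice is then spread over their TQs by __setPrioritiesForEntity.
groupShare = 100.0                            # e.g. /Registry/Groups/<group>/JobShare
owners = [ 'dnA', 'dnB', 'dnC', 'dnD' ]       # distinct OwnerDNs with TQs in the group
perOwnerShare = groupShare / len( owners )    # 25.0 each, before any corrector adjustments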
  def recalculateTQSharesForEntity( self, userDN, userGroup, connObj = False ):
    """ Recalculate the shares for a userDN/userGroup combo """
    self.log.info( "Recalculating shares for %s@%s TQs" % ( userDN, userGroup ) )
    if userGroup in self.__groupShares:
      share = self.__groupShares[ userGroup ]
    else:
      share = float( DEFAULT_GROUP_SHARE )
    if Properties.JOB_SHARING in CS.getPropertiesForGroup( userGroup ):
      # If the group has JobSharing, just set the priority for that entry; userDN is irrelevant
      return self.__setPrioritiesForEntity( userDN, userGroup, share, connObj = connObj )
    selSQL = "SELECT OwnerDN, COUNT(OwnerDN) FROM `tq_TaskQueues` WHERE OwnerGroup='%s' GROUP BY OwnerDN" % userGroup
    result = self._query( selSQL, conn = connObj )
    if not result[ 'OK' ]:
      return result
    # Get the owners in this group and the number of times they appear
    data = [ ( r[0], r[1] ) for r in result[ 'Value' ] if r ]
    numOwners = len( data )
    # If there are no owners there is nothing to do
    if numOwners == 0:
      return S_OK()
    # Split the share amongst the number of owners
    share /= numOwners
    entitiesShares = dict( [ ( row[0], share ) for row in data ] )
    # If the corrector is enabled, let it work its magic
    if self.isSharesCorrectionEnabled():
      entitiesShares = self.__sharesCorrector.correctShares( entitiesShares, group = userGroup )
    owners = dict( data )
    # If the user is already known and has more than one TQ, the other users do not need
    # to be modified ( the number of owners did not change )
    if userDN in owners and owners[ userDN ] > 1:
      return self.__setPrioritiesForEntity( userDN, userGroup, entitiesShares[ userDN ], connObj = connObj )
    # The number of owners may have changed, so recalculate the priority for all owners in the group
    for userDN in owners:
      self.__setPrioritiesForEntity( userDN, userGroup, entitiesShares[ userDN ], connObj = connObj )
    return S_OK()

  def __setPrioritiesForEntity( self, userDN, userGroup, share, connObj = False, consolidationFunc = "AVG" ):
    """ Set the priority for a userDN/userGroup combo given a split share """
    self.log.info( "Setting priorities to %s@%s TQs" % ( userDN, userGroup ) )
    tqCond = [ "t.OwnerGroup='%s'" % userGroup ]
    allowBgTQs = gConfig.getValue( "/Registry/Groups/%s/AllowBackgroundTQs" % userGroup, False )
    if Properties.JOB_SHARING not in CS.getPropertiesForGroup( userGroup ):
      tqCond.append( "t.OwnerDN='%s'" % userDN )
    tqCond.append( "t.TQId = j.TQId" )
    if consolidationFunc == 'AVG':
      selectSQL = "SELECT j.TQId, SUM( j.RealPriority )/COUNT(j.RealPriority) FROM `tq_TaskQueues` t, `tq_Jobs` j WHERE "
    elif consolidationFunc == 'SUM':
      selectSQL = "SELECT j.TQId, SUM( j.RealPriority ) FROM `tq_TaskQueues` t, `tq_Jobs` j WHERE "
    else:
      return S_ERROR( "Unknown consolidation func %s for setting priorities" % consolidationFunc )
    selectSQL += " AND ".join( tqCond )
    selectSQL += " GROUP BY t.TQId"
    result = self._query( selectSQL, conn = connObj )
    if not result[ 'OK' ]:
      return result
    tqDict = dict( result[ 'Value' ] )
    if len( tqDict ) == 0:
      return S_OK()
    # Calculate the sum of the priorities
    totalPrio = 0
    for k in tqDict:
      if tqDict[ k ] > 0.1 or not allowBgTQs:
        totalPrio += tqDict[ k ]
    # Group the TQs by priority
    prioDict = {}
    for tqId in tqDict:
      if tqDict[ tqId ] > 0.1 or not allowBgTQs:
        prio = ( share / totalPrio ) * tqDict[ tqId ]
      else:
        prio = TQ_MIN_SHARE
      prio = max( prio, TQ_MIN_SHARE )
      if prio not in prioDict:
        prioDict[ prio ] = []
      prioDict[ prio ].append( tqId )
    # Execute the updates
    for prio in prioDict:
      tqList = ", ".join( [ str( tqId ) for tqId in prioDict[ prio ] ] )
      updateSQL = "UPDATE `tq_TaskQueues` SET Priority=%.4f WHERE TQId in ( %s )" % ( prio, tqList )
      self._update( updateSQL, conn = connObj )
    return S_OK()

  def getGroupShares( self ):
    """ Get all the group shares as a dict """
    result = gConfig.getSections( "/Registry/Groups" )
    if result[ 'OK' ]:
      groups = result[ 'Value' ]
    else:
      groups = []
    shares = {}
    for group in groups:
      shares[ group ] = gConfig.getValue( "/Registry/Groups/%s/JobShare" % group, DEFAULT_GROUP_SHARE )
    return shares

  def propagateTQSharesIfChanged( self ):
    """ If the shares have changed in the CS, recalculate the priorities """
    shares = self.getGroupShares()
    if shares == self.__groupShares:
      return S_OK()
    self.__groupShares = shares
    return self.recalculateTQSharesForAll()
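# Worked example of the priority rule in __setPrioritiesForEntity, with invented numbers.
# An owner's share is distributed over their TQs in proportion to each TQ's consolidated
# RealPriority, with TQ_MIN_SHARE as a floor ( and as the fixed value for background TQs ):
share = 25.0
tqDict = { 101 : 8.0, 102 : 2.0 }      # TQId -> AVG( RealPriority ) from the SELECT
totalPrio = sum( tqDict.values() )     # 10.0
prios = dict( ( tqId, max( ( share / totalPrio ) * p, 0.001 ) )  # 0.001 stands in for TQ_MIN_SHARE
              for tqId, p in tqDict.items() )                    # -> { 101 : 20.0, 102 : 5.0 }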
  def modifyJobsPriorities( self, jobPrioDict ):
    """ Modify the priority of a set of jobs """
    for jId in jobPrioDict:
      jobPrioDict[ jId ] = int( jobPrioDict[ jId ] )
    maxJobsInQuery = 1000
    jobsList = sorted( jobPrioDict )
    prioDict = {}
    for jId in jobsList:
      prio = jobPrioDict[ jId ]
      if prio not in prioDict:
        prioDict[ prio ] = []
      prioDict[ prio ].append( str( jId ) )
    updated = 0
    for prio in prioDict:
      jobsList = prioDict[ prio ]
      # Update in chunks of at most maxJobsInQuery jobs to keep the queries bounded
      for i in range( 0, len( jobsList ), maxJobsInQuery ):
        jobs = ",".join( jobsList[ i : i + maxJobsInQuery ] )
        updateSQL = "UPDATE `tq_Jobs` SET `Priority`=%s, `RealPriority`=%f WHERE `JobId` in ( %s )" % ( prio, self.__hackJobPriority( prio ), jobs )
        result = self._update( updateSQL )
        if not result[ 'OK' ]:
          return result
        updated += result[ 'Value' ]
    if not updated:
      return S_OK()
    return self.recalculateTQSharesForAll()
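# Minimal standalone check of the chunking loop fixed above ( the original stepped
# range( maxJobsInQuery, 0, len( jobsList ) ), which never iterates for non-trivial lists ):
def chunks( items, size ):
  # Yield consecutive slices of at most `size` elements
  for i in range( 0, len( items ), size ):
    yield items[ i : i + size ]

# e.g. list( chunks( [ 0, 1, 2, 3, 4 ], 2 ) ) -> [ [0, 1], [2, 3], [4] ]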