def testReadFlowForProcessingUpdatesFlowObjects(self):
    """Verifies processing markers are set on lease and cleared on return.

    After ReadFlowForProcessing, both the returned flow and the flow read
    back from the database must carry this process's id, the lease start
    time and the processing deadline. After ReturnProcessedFlow those
    fields must be cleared on both objects, while other modifications
    (next_request_to_process) must be persisted.
    """
    client_id, flow_id = self._SetupClientAndFlow()

    lease_start = rdfvalue.RDFDatetime.Now()
    lease_duration = rdfvalue.Duration("60s")
    expected_deadline = lease_start + lease_duration

    # Freeze the clock so processing_since is exactly lease_start.
    with test_lib.FakeTime(lease_start):
        leased_flow = self.db.ReadFlowForProcessing(
            client_id, flow_id, lease_duration)

    # The object handed back to the caller reflects the lease.
    self.assertEqual(leased_flow.processing_on, utils.ProcessIdString())
    self.assertEqual(leased_flow.processing_since, lease_start)
    self.assertEqual(leased_flow.processing_deadline, expected_deadline)

    # ... and so does the stored copy.
    stored_flow = self.db.ReadFlowObject(client_id, flow_id)
    self.assertEqual(stored_flow.processing_on, utils.ProcessIdString())
    self.assertEqual(stored_flow.processing_since, lease_start)
    self.assertEqual(stored_flow.processing_deadline, expected_deadline)

    # Change an unrelated field and hand the flow back.
    leased_flow.next_request_to_process = 5
    self.db.ReturnProcessedFlow(leased_flow)

    # Returning the flow clears the markers on the passed-in object.
    self.assertFalse(leased_flow.processing_on)
    self.assertIsNone(leased_flow.processing_since)
    self.assertIsNone(leased_flow.processing_deadline)

    # The stored copy is cleared as well and keeps the modified field.
    stored_flow = self.db.ReadFlowObject(client_id, flow_id)
    self.assertFalse(stored_flow.processing_on)
    self.assertIsNone(stored_flow.processing_since)
    self.assertIsNone(stored_flow.processing_deadline)
    self.assertEqual(stored_flow.next_request_to_process, 5)
def LeaseClientMessages(self, client_id, lease_time=None, limit=sys.maxsize):
    """Leases available client messages for the client with the given id.

    A message is available when no lease is recorded for its task id or the
    recorded lease expired before the current time.

    Args:
      client_id: Only messages for which db_utils.ClientIdFromGrrMessage
        returns this value are considered.
      lease_time: Duration added to the current time to compute the lease
        expiry.
      limit: Maximum number of messages to lease in this call.

    Returns:
      A list of the messages leased by this call, each with leased_until
      and leased_by updated.
    """
    leased_messages = []

    now = rdfvalue.RDFDatetime.Now()
    expiration_time = now + lease_time
    process_id_str = utils.ProcessIdString()

    leases = self.client_message_leases
    for msgs_by_id in itervalues(self.client_messages):
        # Lease in task id order so the lowest task ids are taken first.
        for msg in sorted(itervalues(msgs_by_id), key=lambda m: m.task_id):
            if db_utils.ClientIdFromGrrMessage(msg) != client_id:
                continue

            existing_lease = leases.get(msg.task_id)
            if not existing_lease or existing_lease[0] < now:
                leases[msg.task_id] = (expiration_time, process_id_str)
                msg.leased_until = expiration_time
                msg.leased_by = process_id_str
                leased_messages.append(msg)
                if len(leased_messages) >= limit:
                    # Return instead of break: a break would only leave the
                    # inner loop, and the outer loop would keep leasing
                    # messages from the remaining groups past the limit.
                    return leased_messages

    return leased_messages
def _LeaseMessageHandlerRequests(self, lease_time, limit, cursor=None):
    """Leases a number of message handler requests up to the indicated limit.

    Rows that are unleased, or whose lease ended before the current time,
    are stamped with a new expiry and this process's id; the stamped rows
    are then read back and returned.

    Args:
      lease_time: Duration added to the current time to compute the expiry.
      limit: Maximum number of rows to lease.
      cursor: Database cursor used to run both statements.

    Returns:
      A list of MessageHandlerRequest objects with timestamp, leased_until
      and leased_by populated; empty if no row was updated.
    """
    current_time = rdfvalue.RDFDatetime.Now()
    lease_end = current_time + lease_time
    current_time_sql = mysql_utils.RDFDatetimeToMysqlString(current_time)
    lease_end_sql = mysql_utils.RDFDatetimeToMysqlString(lease_end)
    owner = utils.ProcessIdString()

    lease_statement = ("UPDATE message_handler_requests "
                       "SET leased_until=%s, leased_by=%s "
                       "WHERE leased_until IS NULL OR leased_until < %s "
                       "LIMIT %s")
    num_leased = cursor.execute(
        lease_statement, (lease_end_sql, owner, current_time_sql, limit))
    if num_leased == 0:
        return []

    # The rows just leased are identified by the (owner, expiry) pair that
    # the UPDATE wrote.
    cursor.execute(
        "SELECT timestamp, request FROM message_handler_requests "
        "WHERE leased_by=%s AND leased_until=%s LIMIT %s",
        (owner, lease_end_sql, num_leased))

    leased = []
    for row_timestamp, serialized in cursor.fetchall():
        handler_request = (
            rdf_objects.MessageHandlerRequest.FromSerializedString(serialized))
        handler_request.timestamp = mysql_utils.MysqlToRDFDatetime(
            row_timestamp)
        handler_request.leased_until = lease_end
        handler_request.leased_by = owner
        leased.append(handler_request)
    return leased
def ReadFlowForProcessing(self, client_id, flow_id, processing_time):
    """Marks a flow as being processed on this worker and returns it.

    Args:
      client_id: Client id used, together with flow_id, to look up and
        update the flow.
      flow_id: Id of the flow to mark as being processed.
      processing_time: Duration added to the current time to compute the
        processing deadline.

    Returns:
      The flow object with processing_on, processing_since and
      processing_deadline set to the values written by this call.

    Raises:
      db.ParentHuntIsNotRunningError: If the flow has a non-legacy parent
        hunt whose state is not suitable for flow processing.
      ValueError: If the flow is already marked as being processed and its
        processing deadline is still in the future.
    """
    rdf_flow = self.ReadFlowObject(client_id, flow_id)

    # TODO(user): remove the check for a legacy hunt prefix as soon as
    # AFF4 is gone.
    if (rdf_flow.parent_hunt_id and
            not rdf_flow.parent_hunt_id.startswith("H:")):
        rdf_hunt = self.ReadHuntObject(rdf_flow.parent_hunt_id)
        if not rdf_hunt_objects.IsHuntSuitableForFlowProcessing(
                rdf_hunt.hunt_state):
            raise db.ParentHuntIsNotRunningError(
                client_id, flow_id, rdf_hunt.hunt_id, rdf_hunt.hunt_state)

    now = rdfvalue.RDFDatetime.Now()
    if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
        # The message names the flow first and the client second, so the
        # arguments must be passed in that order.
        raise ValueError("Flow %s on client %s is already being processed." %
                         (flow_id, client_id))

    processing_deadline = now + processing_time
    process_id_string = utils.ProcessIdString()
    self.UpdateFlow(
        client_id,
        flow_id,
        processing_on=process_id_string,
        processing_since=now,
        processing_deadline=processing_deadline)

    # Mirror the stored values on the object handed back to the caller.
    rdf_flow.processing_on = process_id_string
    rdf_flow.processing_since = now
    rdf_flow.processing_deadline = processing_deadline
    return rdf_flow
def _LeaseFlowProcessingReqests(self, cursor=None):
    """Leases a number of flow processing requests.

    Up to 50 requests that are due for delivery (no delivery time, or a
    delivery time not after now) and not currently leased are stamped with
    a 10 minute lease owned by this process, then read back.

    Args:
      cursor: Database cursor used to run both statements.

    Returns:
      A list of FlowProcessingRequest objects with timestamp, leased_until
      and leased_by populated; empty if no row was updated.
    """
    current_time = rdfvalue.RDFDatetime.Now()
    lease_end = current_time + rdfvalue.Duration("10m")
    current_time_sql = mysql_utils.RDFDatetimeToMysqlString(current_time)
    lease_end_sql = mysql_utils.RDFDatetimeToMysqlString(lease_end)
    owner = utils.ProcessIdString()

    lease_statement = (
        "UPDATE flow_processing_requests "
        "SET leased_until=%s, leased_by=%s "
        "WHERE (delivery_time IS NULL OR delivery_time <= %s) AND "
        "(leased_until IS NULL OR leased_until < %s) "
        "LIMIT %s")
    lease_params = (
        lease_end_sql, owner, current_time_sql, current_time_sql, 50)

    num_leased = cursor.execute(lease_statement, lease_params)
    if num_leased == 0:
        return []

    # The rows just leased are identified by the (owner, expiry) pair that
    # the UPDATE wrote.
    cursor.execute(
        "SELECT timestamp, request FROM flow_processing_requests "
        "WHERE leased_by=%s AND leased_until=%s LIMIT %s",
        (owner, lease_end_sql, num_leased))

    leased = []
    for row_timestamp, serialized in cursor.fetchall():
        processing_request = (
            rdf_flows.FlowProcessingRequest.FromSerializedString(serialized))
        processing_request.timestamp = mysql_utils.MysqlToRDFDatetime(
            row_timestamp)
        processing_request.leased_until = lease_end
        processing_request.leased_by = owner
        leased.append(processing_request)
    return leased
def LeaseCronJobs(self, cronjob_ids=None, lease_time=None, cursor=None):
    """Leases all available cron jobs.

    A job is available when it has no lease or its lease ended before the
    current time.

    Args:
      cronjob_ids: Optional list of job ids; when non-empty, only these
        jobs are considered.
      lease_time: Duration added to the current time to compute the expiry.
      cursor: Database cursor used to run both statements.

    Returns:
      A list with one object per leased row, built by _CronJobFromRow;
      empty if no row was updated.
    """
    current_time = rdfvalue.RDFDatetime.Now()
    current_ts = mysql_utils.RDFDatetimeToTimestamp(current_time)
    lease_end_ts = mysql_utils.RDFDatetimeToTimestamp(
        current_time + lease_time)
    owner = utils.ProcessIdString()

    lease_statement = (
        "UPDATE cron_jobs "
        "SET leased_until=FROM_UNIXTIME(%s), leased_by=%s "
        "WHERE (leased_until IS NULL OR leased_until < FROM_UNIXTIME(%s))")
    lease_params = [lease_end_ts, owner, current_ts]

    if cronjob_ids:
        # One placeholder per requested job id; the ids themselves are
        # passed as query parameters.
        placeholders = ", ".join(["%s"] * len(cronjob_ids))
        lease_statement += " AND job_id in (%s)" % placeholders
        lease_params.extend(cronjob_ids)

    num_leased = cursor.execute(lease_statement, lease_params)
    if num_leased == 0:
        return []

    # The rows just leased are identified by the (expiry, owner) pair that
    # the UPDATE wrote.
    cursor.execute(
        "SELECT job, UNIX_TIMESTAMP(create_time), enabled,"
        "forced_run_requested, last_run_status, UNIX_TIMESTAMP(last_run_time), "
        "current_run_id, state, UNIX_TIMESTAMP(leased_until), leased_by "
        "FROM cron_jobs "
        "FORCE INDEX (cron_jobs_by_lease) "
        "WHERE leased_until=FROM_UNIXTIME(%s) AND leased_by=%s",
        [lease_end_ts, owner])

    leased_jobs = []
    for job_row in cursor.fetchall():
        leased_jobs.append(self._CronJobFromRow(job_row))
    return leased_jobs
def LeaseFlowForProcessing(self, client_id, flow_id, processing_time):
    """Marks a flow as being processed on this worker and returns it.

    Args:
      client_id: Client id used, together with flow_id, to look up and
        update the flow.
      flow_id: Id of the flow to lease.
      processing_time: Duration added to the current time to compute the
        processing deadline.

    Returns:
      The flow object with processing_on, processing_since and
      processing_deadline set to the values written by this call.

    Raises:
      db.ParentHuntIsNotRunningError: If the flow has a parent hunt whose
        state is not suitable for flow processing.
      ValueError: If the flow is already marked as being processed and its
        processing deadline is still in the future.
    """
    flow_obj = self.ReadFlowObject(client_id, flow_id)

    # Flows started by a hunt may only be processed while the hunt's state
    # allows it.
    if flow_obj.parent_hunt_id:
        parent_hunt = self.ReadHuntObject(flow_obj.parent_hunt_id)
        suitable = rdf_hunt_objects.IsHuntSuitableForFlowProcessing(
            parent_hunt.hunt_state)
        if not suitable:
            raise db.ParentHuntIsNotRunningError(client_id, flow_id,
                                                 parent_hunt.hunt_id,
                                                 parent_hunt.hunt_state)

    lease_start = rdfvalue.RDFDatetime.Now()

    # An existing lease only blocks us while its deadline lies in the future.
    if flow_obj.processing_on and flow_obj.processing_deadline > lease_start:
        raise ValueError("Flow %s on client %s is already being processed." %
                         (flow_id, client_id))

    lease_deadline = lease_start + processing_time
    owner = utils.ProcessIdString()

    self.UpdateFlow(
        client_id,
        flow_id,
        processing_on=owner,
        processing_since=lease_start,
        processing_deadline=lease_deadline)

    # Mirror the stored values on the object handed back to the caller.
    flow_obj.processing_on = owner
    flow_obj.processing_since = lease_start
    flow_obj.processing_deadline = lease_deadline
    return flow_obj
def LeaseClientMessages(self, client_id, lease_time=None, limit=sys.maxsize):
    """Leases available client messages for the client with the given id.

    A message is available when no lease is recorded for its task id or the
    recorded lease expired before the current time. Each lease entry also
    stores how often the message has been re-leased; a message whose next
    count would reach db.Database.CLIENT_MESSAGES_TTL is deleted instead of
    being leased again.

    Args:
      client_id: Only messages for which db_utils.ClientIdFromGrrMessage
        returns this value are considered.
      lease_time: Duration added to the current time to compute the lease
        expiry.
      limit: Maximum number of messages to lease in this call.

    Returns:
      A list of the messages leased by this call, each with leased_until
      and leased_by updated.
    """
    leased_messages = []

    now = rdfvalue.RDFDatetime.Now()
    expiration_time = now + lease_time
    process_id_str = utils.ProcessIdString()

    leases = self.client_message_leases
    for msgs_by_id in itervalues(self.client_messages):
        # Lease in task id order so the lowest task ids are taken first.
        for msg in sorted(itervalues(msgs_by_id), key=lambda m: m.task_id):
            if db_utils.ClientIdFromGrrMessage(msg) != client_id:
                continue

            existing_lease = leases.get(msg.task_id)
            if not existing_lease or existing_lease[0] < now:
                if existing_lease:
                    lease_count = existing_lease[-1] + 1
                    # >= comparison since this check happens before the
                    # lease.
                    if lease_count >= db.Database.CLIENT_MESSAGES_TTL:
                        self._DeleteClientMessage(client_id, msg.task_id)
                        continue
                else:
                    # First lease of this message.
                    lease_count = 0

                leases[msg.task_id] = (expiration_time, process_id_str,
                                       lease_count)
                msg.leased_until = expiration_time
                msg.leased_by = process_id_str
                leased_messages.append(msg)
                if len(leased_messages) >= limit:
                    # Return instead of break: a break would only leave the
                    # inner loop, and the outer loop would keep leasing
                    # messages from the remaining groups past the limit.
                    return leased_messages

    return leased_messages
def LeaseClientMessages(self,
                        client_id,
                        lease_time=None,
                        limit=None,
                        cursor=None):
    """Leases available client messages for the client with the given id.

    Rows for the client that are unleased, or whose lease ended before the
    current time, are stamped with a new expiry and this process's id and
    get their leased_count incremented. The stamped rows are then read
    back; messages whose leased_count exceeds
    db.Database.CLIENT_MESSAGES_TTL are deleted instead of being returned.

    Args:
      client_id: Id of the client whose messages should be leased.
      lease_time: Duration added to the current time to compute the expiry.
      limit: Maximum number of rows to lease, or None for no limit.
      cursor: Database cursor used to run the statements.

    Returns:
      The leased, non-expired messages sorted by task id, each with
      leased_by, leased_until and ttl set.
    """
    now = rdfvalue.RDFDatetime.Now()
    now_str = mysql_utils.RDFDatetimeToMysqlString(now)
    expiry = now + lease_time
    expiry_str = mysql_utils.RDFDatetimeToMysqlString(expiry)
    proc_id_str = utils.ProcessIdString()
    client_id_int = mysql_utils.ClientIDToInt(client_id)

    query = ("UPDATE client_messages "
             "SET leased_until=%s, leased_by=%s, leased_count=leased_count+1 "
             "WHERE client_id=%s AND "
             "(leased_until IS NULL OR leased_until < %s)")
    args = [expiry_str, proc_id_str, client_id_int, now_str]
    if limit is not None:
        # Only add the clause when a limit was given: binding None would
        # render "LIMIT NULL", which is not valid SQL.
        query += " LIMIT %s"
        args.append(limit)

    num_leased = cursor.execute(query, args)
    if num_leased == 0:
        return []

    # The rows just leased are identified by the (expiry, owner) pair that
    # the UPDATE wrote.
    query = ("SELECT message, leased_count FROM client_messages "
             "WHERE client_id=%s AND leased_until=%s AND leased_by=%s")
    cursor.execute(query, [client_id_int, expiry_str, proc_id_str])

    ret = []
    expired = []
    for msg, leased_count in cursor.fetchall():
        message = rdf_flows.GrrMessage.FromSerializedString(msg)
        message.leased_by = proc_id_str
        message.leased_until = expiry
        message.ttl = db.Database.CLIENT_MESSAGES_TTL - leased_count
        # > comparison since this check happens after the lease.
        if leased_count > db.Database.CLIENT_MESSAGES_TTL:
            expired.append((client_id, message.task_id))
        else:
            ret.append(message)

    if expired:
        self._DeleteClientMessages(expired, cursor=cursor)

    return sorted(ret, key=lambda msg: msg.task_id)
def ReadFlowForProcessing(self,
                          client_id,
                          flow_id,
                          processing_time,
                          cursor=None):
    """Marks a flow as being processed on this worker and returns it.

    Args:
      client_id: Client id of the flow row to lease.
      flow_id: Flow id of the flow row to lease.
      processing_time: Duration added to the current time to compute the
        processing deadline.
      cursor: Database cursor used to run the statements.

    Returns:
      The flow object built from the stored row, with processing_on,
      processing_since and processing_deadline set to the values written
      by this call.

    Raises:
      db.UnknownFlowError: If no row matches client_id and flow_id.
      ValueError: If the flow is already marked as being processed and its
        processing deadline is still in the future.
    """
    client_id_int = mysql_utils.ClientIDToInt(client_id)
    flow_id_int = mysql_utils.FlowIDToInt(flow_id)

    query = ("SELECT " + self.FLOW_DB_FIELDS +
             "FROM flows WHERE client_id=%s AND flow_id=%s")
    cursor.execute(query, [client_id_int, flow_id_int])
    response = cursor.fetchall()
    if not response:
        raise db.UnknownFlowError(client_id, flow_id)

    # Exactly one row is expected; unpacking fails loudly otherwise.
    row, = response
    rdf_flow = self._FlowObjectFromRow(row)

    now = rdfvalue.RDFDatetime.Now()
    if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
        # The message names the flow first and the client second, so the
        # arguments must be passed in that order.
        raise ValueError("Flow %s on client %s is already being processed." %
                         (flow_id, client_id))

    update_query = (
        "UPDATE flows SET processing_on=%s, processing_since=%s, "
        "processing_deadline=%s WHERE client_id=%s and flow_id=%s")
    processing_deadline = now + processing_time
    process_id_string = utils.ProcessIdString()

    args = [
        process_id_string,
        mysql_utils.RDFDatetimeToMysqlString(now),
        mysql_utils.RDFDatetimeToMysqlString(processing_deadline),
        client_id_int,
        flow_id_int,
    ]
    cursor.execute(update_query, args)

    # This needs to happen after we are sure that the write has succeeded.
    rdf_flow.processing_on = process_id_string
    rdf_flow.processing_since = now
    rdf_flow.processing_deadline = processing_deadline
    return rdf_flow
def ReadFlowForProcessing(self, client_id, flow_id, processing_time):
    """Marks a flow as being processed on this worker and returns it.

    Args:
      client_id: Client id used, together with flow_id, to look up and
        update the flow.
      flow_id: Id of the flow to mark as being processed.
      processing_time: Duration added to the current time to compute the
        processing deadline.

    Returns:
      The flow object with processing_on, processing_since and
      processing_deadline set to the values written by this call.

    Raises:
      ValueError: If the flow is already marked as being processed and its
        processing deadline is still in the future.
    """
    rdf_flow = self.ReadFlowObject(client_id, flow_id)

    now = rdfvalue.RDFDatetime.Now()
    if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
        # The message names the flow first and the client second, so the
        # arguments must be passed in that order.
        raise ValueError("Flow %s on client %s is already being processed." %
                         (flow_id, client_id))

    processing_deadline = now + processing_time
    process_id_string = utils.ProcessIdString()
    self.UpdateFlow(
        client_id,
        flow_id,
        processing_on=process_id_string,
        processing_since=now,
        processing_deadline=processing_deadline)

    # Mirror the stored values on the object handed back to the caller.
    rdf_flow.processing_on = process_id_string
    rdf_flow.processing_since = now
    rdf_flow.processing_deadline = processing_deadline
    return rdf_flow
def LeaseCronJobs(self, cronjob_ids=None, lease_time=None):
    """Leases all available cron jobs.

    A job is available when no lease is recorded for it or the recorded
    lease expired before the current time.

    Args:
      cronjob_ids: Optional collection of job ids; when non-empty, only
        these jobs are considered.
      lease_time: Duration added to the current time to compute the expiry.

    Returns:
      A list of copies of the leased jobs, each with leased_until and
      leased_by set to the newly recorded lease.
    """
    current_time = rdfvalue.RDFDatetime.Now()
    lease_end = current_time + lease_time

    result = []
    for stored_job in self.cronjobs.values():
        job_id = stored_job.cron_job_id

        # Skip jobs the caller did not ask for.
        if cronjob_ids and job_id not in cronjob_ids:
            continue

        # Skip jobs whose current lease has not run out yet.
        current_lease = self.cronjob_leases.get(job_id)
        if current_lease is not None and not current_lease[0] < current_time:
            continue

        new_lease = (lease_end, utils.ProcessIdString())
        self.cronjob_leases[job_id] = new_lease

        # Hand out a copy so the stored job object is left untouched.
        leased_copy = stored_job.Copy()
        leased_copy.leased_until, leased_copy.leased_by = new_lease
        result.append(leased_copy)

    return result
def _LeaseMessageHandlerRequests(self, lease_time, limit):
    """Read and lease some outstanding message handler requests.

    A request is available when no lease is recorded for its handler name
    and request id, or the recorded lease expired before the current time.

    Args:
      lease_time: Duration added to the current time to compute the lease
        expiry.
      limit: Maximum number of requests to lease in this call.

    Returns:
      A list of the requests leased by this call, each with leased_until
      and leased_by updated.
    """
    leased_requests = []

    now = rdfvalue.RDFDatetime.Now()
    # Stand-in expiry for requests that have never been leased; it is
    # compared against now below.
    zero = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0)
    expiration_time = now + lease_time
    process_id_str = utils.ProcessIdString()

    leases = self.message_handler_leases
    for requests in itervalues(self.message_handler_requests):
        for r in itervalues(requests):
            existing_lease = leases.get(r.handler_name, {}).get(
                r.request_id, zero)
            if existing_lease < now:
                leases.setdefault(r.handler_name,
                                  {})[r.request_id] = expiration_time
                r.leased_until = expiration_time
                r.leased_by = process_id_str
                leased_requests.append(r)
                if len(leased_requests) >= limit:
                    # Return instead of break: a break would only leave the
                    # inner loop, and the outer loop would keep leasing
                    # requests from the remaining groups past the limit.
                    return leased_requests

    return leased_requests
def testMessageHandlerRequestLeasing(self):
    """Verifies a registered handler receives all requests in leased batches.

    Ten requests are written, a handler is registered with a batch limit of
    five, and the test waits until all ten have been delivered. Each batch
    must respect the limit and every delivered request must carry a valid
    lease held by this process.
    """
    expected = [
        rdf_objects.MessageHandlerRequest(
            client_id="C.1000000000000000",
            handler_name="Testhandler",
            request_id=i * 100,
            request=rdfvalue.RDFInteger(i)) for i in range(10)
    ]
    lease_time = rdfvalue.Duration("5m")

    # Write in the past so the stored timestamps precede the real "now".
    write_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000)
    with test_lib.FakeTime(write_time):
        self.db.WriteMessageHandlerRequests(expected)

    # The handler simply queues every batch it is handed.
    batches = queue.Queue()
    self.db.RegisterMessageHandler(batches.put, lease_time, limit=5)

    received = []
    while len(received) < 10:
        try:
            batch = batches.get(True, timeout=6)
        except queue.Empty:
            self.fail("Timed out waiting for messages, expected 10, got %d" %
                      len(received))

        self.assertLessEqual(len(batch), 5)
        for delivered in batch:
            self.assertEqual(delivered.leased_by, utils.ProcessIdString())
            self.assertGreater(delivered.leased_until,
                               rdfvalue.RDFDatetime.Now())
            self.assertLess(delivered.timestamp, rdfvalue.RDFDatetime.Now())
            # Strip the database-populated fields so the requests compare
            # equal to the ones originally written.
            delivered.leased_by = None
            delivered.leased_until = None
            delivered.timestamp = None
        received.extend(batch)

    self.db.DeleteMessageHandlerRequests(received)

    # Delivery order is not asserted, only the delivered set.
    received.sort(key=lambda req: req.request_id)
    self.assertEqual(expected, received)
def LeaseClientActionRequests(self,
                              client_id,
                              lease_time=None,
                              limit=sys.maxsize):
    """Leases available client action requests for a client.

    A request is available when no lease is recorded for its key or the
    recorded lease expired before the current time. Each lease entry also
    stores how often the request has been leased; a request whose next
    count would exceed db.Database.CLIENT_MESSAGES_TTL is deleted instead
    of being leased again.

    Args:
      client_id: Only requests whose key starts with this client id are
        considered.
      lease_time: Duration added to the current time to compute the lease
        expiry.
      limit: Maximum number of requests to lease in this call.

    Returns:
      A list of the requests leased by this call, each with leased_until,
      leased_by and ttl updated.
    """
    current_time = rdfvalue.RDFDatetime.Now()
    lease_end = current_time + lease_time
    owner = utils.ProcessIdString()
    max_leases = db.Database.CLIENT_MESSAGES_TTL

    lease_table = self.client_action_request_leases
    result = []

    # Can't use an iterator here since the dict might change when requests
    # get deleted.
    for request_key, action_request in sorted(
            self.client_action_requests.items()):
        if request_key[0] != client_id:
            continue

        current_lease = lease_table.get(request_key)
        # Skip requests whose lease is still running.
        if current_lease and not current_lease[0] < current_time:
            continue

        if current_lease:
            times_leased = current_lease[-1] + 1
            if times_leased > max_leases:
                # Leased too often already: drop the request for good.
                self._DeleteClientActionRequest(*request_key)
                continue
        else:
            times_leased = 1

        lease_table[request_key] = (lease_end, owner, times_leased)
        action_request.leased_until = lease_end
        action_request.leased_by = owner
        action_request.ttl = max_leases - times_leased
        result.append(action_request)

        if len(result) >= limit:
            break

    return result
def LeaseClientMessages(self,
                        client_id,
                        lease_time=None,
                        limit=None,
                        cursor=None):
    """Leases available client messages for the client with the given id.

    Rows for the client that are unleased, or whose lease ended before the
    current time, are stamped with a new expiry and this process's id; the
    stamped rows are then read back and returned.

    Args:
      client_id: Id of the client whose messages should be leased.
      lease_time: Duration added to the current time to compute the expiry.
      limit: Maximum number of rows to lease, or None for no limit.
      cursor: Database cursor used to run both statements.

    Returns:
      The leased messages sorted by task id, each with leased_by and
      leased_until set; empty if no row was updated.
    """
    now = rdfvalue.RDFDatetime.Now()
    now_str = mysql_utils.RDFDatetimeToMysqlString(now)
    expiry = now + lease_time
    expiry_str = mysql_utils.RDFDatetimeToMysqlString(expiry)
    proc_id_str = utils.ProcessIdString()
    client_id_int = mysql_utils.ClientIDToInt(client_id)

    query = ("UPDATE client_messages "
             "SET leased_until=%s, leased_by=%s "
             "WHERE client_id=%s AND "
             "(leased_until IS NULL OR leased_until < %s)")
    args = [expiry_str, proc_id_str, client_id_int, now_str]
    if limit is not None:
        # Only add the clause when a limit was given: binding None would
        # render "LIMIT NULL", which is not valid SQL.
        query += " LIMIT %s"
        args.append(limit)

    num_leased = cursor.execute(query, args)
    if num_leased == 0:
        return []

    # The rows just leased are identified by the (expiry, owner) pair that
    # the UPDATE wrote.
    query = ("SELECT message FROM client_messages "
             "WHERE client_id=%s AND leased_until=%s AND leased_by=%s")
    cursor.execute(query, [client_id_int, expiry_str, proc_id_str])

    ret = []
    for msg, in cursor.fetchall():
        message = rdf_flows.GrrMessage.FromSerializedString(msg)
        message.leased_by = proc_id_str
        message.leased_until = expiry
        ret.append(message)

    return sorted(ret, key=lambda msg: msg.task_id)
def testClientMessageLeasing(self):
    """Verifies leasing of client messages across several points in time.

    Ten messages are written and leased in two batches of five at
    different times. While both leases are active nothing more can be
    leased; once the first batch's lease runs out five messages become
    available again, and once every lease has run out all ten do.
    """
    client_id = self.InitializeClient()
    written = [
        rdf_flows.GrrMessage(queue=client_id, generate_task_id=True)
        for _ in range(10)
    ]
    lease_time = rdfvalue.Duration("5m")

    self.db.WriteClientMessages(written)

    # First batch of five, leased at t0.
    t0 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(100000)
    with test_lib.FakeTime(t0):
        t0_expiry = t0 + lease_time
        batch = self.db.LeaseClientMessages(
            client_id, lease_time=lease_time, limit=5)

        self.assertEqual(len(batch), 5)
        for leased_msg in batch:
            self.assertEqual(leased_msg.leased_until, t0_expiry)
            self.assertEqual(leased_msg.leased_by, utils.ProcessIdString())

    # Second batch of five, leased 100 seconds later.
    t1 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(100000 + 100)
    with test_lib.FakeTime(t1):
        t1_expiry = t1 + lease_time
        batch = self.db.LeaseClientMessages(
            client_id, lease_time=lease_time, limit=5)

        self.assertEqual(len(batch), 5)
        for leased_msg in batch:
            self.assertEqual(leased_msg.leased_until, t1_expiry)
            self.assertEqual(leased_msg.leased_by, utils.ProcessIdString())

        # Nothing left to lease.
        batch = self.db.LeaseClientMessages(
            client_id, lease_time=lease_time, limit=2)
        self.assertEqual(len(batch), 0)

    # The stored messages reflect both leases.
    stored = self.db.ReadClientMessages(client_id)
    self.assertEqual(len(stored), 10)
    for stored_msg in stored:
        self.assertEqual(stored_msg.leased_by, utils.ProcessIdString())

    leased_at_t0 = sum(1 for m in stored if m.leased_until == t0_expiry)
    self.assertEqual(leased_at_t0, 5)
    leased_at_t1 = sum(1 for m in stored if m.leased_until == t1_expiry)
    self.assertEqual(leased_at_t1, 5)

    # Half the leases expired.
    t2 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(100000 + 350)
    with test_lib.FakeTime(t2):
        batch = self.db.LeaseClientMessages(client_id, lease_time=lease_time)
        self.assertEqual(len(batch), 5)

    # All of them expired.
    t3 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(100000 + 10350)
    with test_lib.FakeTime(t3):
        batch = self.db.LeaseClientMessages(client_id, lease_time=lease_time)
        self.assertEqual(len(batch), 10)