def WriteFlowRequests(self, requests):
  """Stores flow requests, rejecting the whole batch if a flow is unknown.

  All requests are validated before anything is stored. Each stored entry is
  a copy of the request, stamped with the current time. For every request
  flagged `needs_processing` that is also the request its flow is waiting
  for (`next_request_to_process`), a FlowProcessingRequest is written.

  Args:
    requests: A list of flow request objects.

  Raises:
    db.AtLeastOneUnknownFlowError: if a request references a flow that is not
      in `self.flows`. Only the first such flow is reported.
  """
  for req in requests:
    flow_key = (req.client_id, req.flow_id)
    if flow_key not in self.flows:
      raise db.AtLeastOneUnknownFlowError([flow_key])

  to_notify = []
  for req in requests:
    flow_key = (req.client_id, req.flow_id)
    per_flow = self.flow_requests.setdefault(flow_key, {})
    stored = req.Copy()
    per_flow[req.request_id] = stored
    stored.timestamp = rdfvalue.RDFDatetime.Now()

    if not req.needs_processing:
      continue
    # Only the request the flow is currently waiting on triggers processing.
    if self.flows[flow_key].next_request_to_process == req.request_id:
      to_notify.append(
          rdf_flows.FlowProcessingRequest(
              client_id=req.client_id, flow_id=req.flow_id))

  if to_notify:
    self.WriteFlowProcessingRequests(to_notify)
def testFlowProcessingRequestsQueue(self):
  """A registered handler receives every written FlowProcessingRequest.

  The handler acknowledges each request it receives and forwards it to the
  test thread through a queue. The test waits for all five requests and
  compares them with what was written.
  """
  client_id, _ = self._SetupClientAndFlow()
  flow_ids = [u"1234ABC%d" % i for i in range(5)]
  queue = Queue.Queue()

  def Callback(request):
    # Acknowledge the request, then hand it over to the test thread.
    self.db.AckFlowProcessingRequests([request])
    queue.put(request)

  self.db.RegisterFlowProcessingHandler(Callback)
  # Unregister in `finally` so that a failed assertion or a timeout does not
  # leave the handler registered for subsequent tests.
  try:
    requests = [
        rdf_flows.FlowProcessingRequest(client_id=client_id, flow_id=flow_id)
        for flow_id in flow_ids
    ]
    self.db.WriteFlowProcessingRequests(requests)

    got = []
    while len(got) < 5:
      try:
        msg = queue.get(True, timeout=6)
      except Queue.Empty:
        self.fail(
            "Timed out waiting for messages, expected 5, got %d" % len(got))
      got.append(msg)

    got.sort(key=lambda req: req.flow_id)
    self.assertEqual(requests, got)
  finally:
    self.db.UnregisterFlowProcessingHandler()
def testAcknowledgingFlowProcessingRequestsWorks(self):
  """Acked requests are removed; DeleteAllFlowProcessingRequests clears all."""
  client_id, _ = self._SetupClientAndFlow()
  flow_ids = [u"1234ABC%d" % i for i in range(5)]
  # Delivery is scheduled 10 minutes ahead (presumably so nothing consumes
  # the requests while the test inspects them - confirm).
  delivery_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration("10m")
  requests = [
      rdf_flows.FlowProcessingRequest(
          client_id=client_id, flow_id=flow_id, delivery_time=delivery_time)
      for flow_id in flow_ids
  ]
  self.db.WriteFlowProcessingRequests(requests)

  # All five requests must be readable back, in flow_id order once sorted.
  stored = self.db.ReadFlowProcessingRequests()
  stored.sort(key=lambda r: r.flow_id)
  self.assertEqual(len(stored), 5)
  self.assertEqual([r.flow_id for r in stored], flow_ids)

  # Acking the second and third request must leave exactly three behind.
  self.db.AckFlowProcessingRequests(stored[1:3])
  remaining = self.db.ReadFlowProcessingRequests()
  self.assertEqual(len(remaining), 3)
  self.assertItemsEqual([r.flow_id for r in remaining],
                        [flow_ids[idx] for idx in (0, 3, 4)])

  # Deleting everything must empty the table.
  self.db.DeleteAllFlowProcessingRequests()
  self.assertEqual(self.db.ReadFlowProcessingRequests(), [])
  self.db.UnregisterFlowProcessingHandler()
def testFlowProcessingRequestsQueueWithDelay(self):
  """Requests with a future delivery_time are delivered only after it."""
  client_id, flow_id = self._SetupClientAndFlow()
  queue = Queue.Queue()
  self.db.RegisterFlowProcessingHandler(queue.put)
  # Unregister in `finally` so that a failed assertion or a timeout does not
  # leave the handler registered for subsequent tests.
  try:
    # The delivery time is built with Duration arithmetic. The previous
    # FromSecondsSinceEpoch(now.AsSecondsSinceEpoch() + 0.5) rebuilt the time
    # from whole seconds, so it could land *before* `now`. That made the
    # "delivered after delivery_time" check below pass vacuously.
    delivery_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration("1s")
    requests = [
        rdf_flows.FlowProcessingRequest(
            client_id=client_id,
            flow_id=flow_id,
            request_id=i,
            delivery_time=delivery_time) for i in range(5)
    ]
    self.db.WriteFlowProcessingRequests(requests)

    got = []
    while len(got) < 5:
      try:
        msg = queue.get(True, timeout=6)
      except Queue.Empty:
        self.fail(
            "Timed out waiting for messages, expected 5, got %d" % len(got))
      got.append(msg)
      # Nothing may be delivered before its scheduled time.
      self.assertGreater(rdfvalue.RDFDatetime.Now(), msg.delivery_time)

    got.sort(key=lambda req: req.request_id)
    self.assertEqual(requests, got)
  finally:
    self.db.UnregisterFlowProcessingHandler()
def testFlowProcessingRequestsQueue(self):
  """A registered handler receives every written FlowProcessingRequest."""
  client_id, flow_id = self._SetupClientAndFlow()
  queue = Queue.Queue()
  self.db.RegisterFlowProcessingHandler(queue.put)
  # Unregister in `finally` so that a failed assertion or a timeout does not
  # leave the handler registered for subsequent tests.
  try:
    requests = [
        rdf_flows.FlowProcessingRequest(
            client_id=client_id, flow_id=flow_id, request_id=i)
        for i in range(5)
    ]
    self.db.WriteFlowProcessingRequests(requests)

    got = []
    while len(got) < 5:
      try:
        msg = queue.get(True, timeout=6)
      except Queue.Empty:
        self.fail(
            "Timed out waiting for messages, expected 5, got %d" % len(got))
      got.append(msg)

    got.sort(key=lambda req: req.request_id)
    self.assertEqual(requests, got)
  finally:
    self.db.UnregisterFlowProcessingHandler()
def WriteFlowRequests(self, requests, cursor=None):
  """Writes a list of flow requests to the database.

  The requests are first inserted into `flow_requests`. Then, for every flow
  that received at least one request flagged `needs_processing`, the flow's
  `next_request_to_process` is read. If that id is among the newly written
  flagged requests, a FlowProcessingRequest is written for the flow.

  Args:
    requests: A list of flow request objects. An empty list is a no-op.
    cursor: Cursor used to run all queries.

  Raises:
    db.AtLeastOneUnknownFlowError: if the insert fails with
      MySQLdb.IntegrityError, which is reported as an unknown flow.
  """
  if not requests:
    # An INSERT with an empty VALUES list is invalid SQL.
    return

  now_str = mysql_utils.RDFDatetimeToMysqlString(rdfvalue.RDFDatetime.Now())

  templates = []
  args = []
  flow_keys = []
  # (client_id, flow_id) -> ids of the written requests that need processing.
  needs_processing = {}
  for r in requests:
    flow_key = (r.client_id, r.flow_id)
    flow_keys.append(flow_key)
    if r.needs_processing:
      needs_processing.setdefault(flow_key, []).append(r.request_id)
    templates.append("(%s, %s, %s, %s, %s, %s)")
    args.extend([
        mysql_utils.ClientIDToInt(r.client_id),
        mysql_utils.FlowIDToInt(r.flow_id), r.request_id, r.needs_processing,
        r.SerializeToString(), now_str
    ])

  # Insert before writing any FlowProcessingRequests, so that a failing
  # insert never leaves processing requests behind for requests that were
  # not stored.
  query = ("INSERT INTO flow_requests "
           "(client_id, flow_id, request_id, needs_processing, request, "
           "timestamp) VALUES ")
  query += ", ".join(templates)
  try:
    cursor.execute(query, args)
  except MySQLdb.IntegrityError as e:
    raise db.AtLeastOneUnknownFlowError(flow_keys, cause=e)

  if not needs_processing:
    return

  nr_conditions = []
  nr_args = []
  for client_id, flow_id in needs_processing:
    nr_conditions.append("(client_id=%s AND flow_id=%s)")
    nr_args.append(mysql_utils.ClientIDToInt(client_id))
    nr_args.append(mysql_utils.FlowIDToInt(flow_id))
  nr_query = ("SELECT client_id, flow_id, next_request_to_process "
              "FROM flows WHERE ")
  nr_query += " OR ".join(nr_conditions)
  cursor.execute(nr_query, nr_args)

  flow_processing_requests = []
  for client_id_int, flow_id_int, next_request_to_process in cursor.fetchall():
    client_id = mysql_utils.IntToClientID(client_id_int)
    flow_id = mysql_utils.IntToFlowID(flow_id_int)
    # Only notify the flow if it is waiting on one of the new requests.
    if next_request_to_process in needs_processing[(client_id, flow_id)]:
      flow_processing_requests.append(
          rdf_flows.FlowProcessingRequest(client_id=client_id, flow_id=flow_id))

  if flow_processing_requests:
    self._WriteFlowProcessingRequests(flow_processing_requests, cursor)
def testRaisesIfFlowProcessingRequestDoesNotTriggerAnyProcessing(self):
  """ProcessFlow raises FlowHasNothingToProcessError for an idle request."""
  with flow_test_lib.TestWorker() as worker:
    started_flow_id = flow.StartFlow(
        flow_cls=CallClientParentFlow, client_id=self.client_id)
    processing_request = rdf_flows.FlowProcessingRequest(
        client_id=self.client_id, flow_id=started_flow_id)
    with self.assertRaises(worker_lib.FlowHasNothingToProcessError):
      worker.ProcessFlow(processing_request)
def WriteFlowResponses(self, responses):
  """Writes a list of flow responses to the database.

  Every response is checked against `self.flows` before anything is stored,
  so an unknown flow leaves the stored state unchanged. Responses for
  requests that are unknown within a known flow are logged and dropped.

  A FlowStatus response sets its request's `nr_responses_expected` to the
  status's `response_id`. A request whose stored response count then equals
  that number is flagged `needs_processing`. If the flow is waiting on that
  request (`next_request_to_process`), a FlowProcessingRequest is written.

  Args:
    responses: A list of flow response objects, possibly including
      FlowStatus objects.

  Raises:
    db.UnknownFlowError: if any response references a flow that is not in
      `self.flows` (the first one found is reported).
  """
  # Materialized because the responses are iterated more than once.
  responses = list(responses)

  # Validate first so that an error cannot leave a partially written batch.
  for response in responses:
    if (response.client_id, response.flow_id) not in self.flows:
      raise db.UnknownFlowError(response.client_id, response.flow_id)

  status_available = set()
  requests_updated = set()
  for response in responses:
    flow_key = (response.client_id, response.flow_id)
    request_dict = self.flow_requests.get(flow_key, {})
    if response.request_id not in request_dict:
      logging.error(
          "Received response for unknown request %s, %s, %d.",
          response.client_id, response.flow_id, response.request_id)
      continue

    response_dict = self.flow_responses.setdefault(flow_key, {})
    response_dict.setdefault(
        response.request_id, {})[response.response_id] = response.Copy()

    if isinstance(response, rdf_flow_objects.FlowStatus):
      status_available.add(response)
    requests_updated.add(
        (response.client_id, response.flow_id, response.request_id))

  # Every time we get a status we store how many responses are expected.
  for status in status_available:
    request_dict = self.flow_requests[(status.client_id, status.flow_id)]
    request = request_dict[status.request_id]
    request.nr_responses_expected = status.response_id

  # Check every updated request to see whether it is now complete.
  needs_processing = []
  for client_id, flow_id, request_id in requests_updated:
    flow_key = (client_id, flow_id)
    request = self.flow_requests[flow_key][request_id]
    if request.nr_responses_expected and not request.needs_processing:
      # Named so that the `responses` argument is not rebound here.
      stored_responses = self.flow_responses.setdefault(flow_key, {}).get(
          request_id, {})
      if len(stored_responses) == request.nr_responses_expected:
        request.needs_processing = True
        flow = self.flows[flow_key]
        if flow.next_request_to_process == request_id:
          needs_processing.append(
              rdf_flows.FlowProcessingRequest(
                  client_id=client_id, flow_id=flow_id,
                  request_id=request_id))

  if needs_processing:
    self.WriteFlowProcessingRequests(needs_processing)
def testFlowProcessingRequestDeletion(self):
  """DeleteFlowProcessingRequests removes exactly the requests passed in."""
  client_id, flow_id = self._SetupClientAndFlow()
  delivery_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration("10m")
  requests = [
      rdf_flows.FlowProcessingRequest(
          client_id=client_id,
          flow_id=flow_id,
          request_id=request_id,
          delivery_time=delivery_time) for request_id in range(5)
  ]
  self.db.WriteFlowProcessingRequests(requests)

  def AssertStoredIds(expected_ids):
    stored = self.db.ReadFlowProcessingRequests()
    self.assertEqual(len(stored), len(expected_ids))
    self.assertItemsEqual([r.request_id for r in stored], expected_ids)

  AssertStoredIds([0, 1, 2, 3, 4])
  self.db.DeleteFlowProcessingRequests(requests[1:3])
  AssertStoredIds([0, 3, 4])
def WriteFlowResponses(self, responses, cursor=None):
  """Writes a list of flow responses to the database.

  Besides inserting the responses, this updates the affected flow requests:
  - statuses record the number of expected responses for their request;
  - requests whose responses are now complete get their needs_processing
    flag set, and their client messages are deleted when a task id is known;
  - flows whose next request to process just became complete receive a
    FlowProcessingRequest.

  Args:
    responses: List of flow response objects. FlowStatus entries carry the
      expected response count in their response_id.
    cursor: Cursor passed through to the helper methods that run the
      queries.

  Raises:
    ValueError: if a status disagrees with an expected response count that
      is already known for its request.
  """
  if not responses:
    return

  # In addition to just writing responses, this function needs to also
  # - Update the expected nr of response for each request that received a
  #   status.
  # - Set the needs_processing flag for all requests that now have all
  #   responses.
  # - Send FlowProcessingRequests for all flows that are waiting on a request
  #   whose needs_processing flag was just set.

  # To achieve this, we need to get the next request each flow is waiting for.
  next_request_by_flow = {}
  # And the number of responses each affected request is waiting for (if
  # available).
  responses_expected_by_request = {}
  # As well as the ids of the currently available responses for each request.
  current_responses_by_request = {}
  # We also store all requests we have in the db so we can discard responses
  # for unknown requests right away.
  currently_available_requests = set()

  # Fills the four containers above from the database.
  self._ReadCurrentFlowInfo(
      responses, currently_available_requests, next_request_by_flow,
      responses_expected_by_request, current_responses_by_request, cursor)

  # For some requests we will need to update the number of expected responses.
  needs_expected_update = {}
  # For some we will need to update the needs_processing flag.
  needs_processing_update = set()
  # Some completed requests will trigger a flow processing request, we collect
  # them in:
  flow_processing_requests = []

  # Pass 1: collect task ids and apply status information.
  # NOTE(review): this pass does not consult currently_available_requests, so
  # a status for an unknown request still lands in needs_expected_update and
  # can still raise ValueError.
  task_ids_by_request = {}
  for r in responses:
    request_key = (r.client_id, r.flow_id, r.request_id)
    try:
      # If this is a response coming from a client, a task_id will be set. We
      # store it in case the request is complete and we can remove the client
      # messages.
      task_ids_by_request[request_key] = r.task_id
    except AttributeError:
      pass

    if not isinstance(r, rdf_flow_objects.FlowStatus):
      continue

    # A missing (or falsy) count is treated as "not known yet".
    current = responses_expected_by_request.get(request_key)
    if current:
      logging.error("Got duplicate status message for request %s/%s/%d",
                    r.client_id, r.flow_id, r.request_id)

      # If there is already responses_expected information, we need to make
      # sure the current status doesn't disagree.
      if current != r.response_id:
        raise ValueError(
            "Got conflicting status information for request %s: %s" %
            (request_key, r))
    else:
      needs_expected_update[request_key] = r.response_id

    responses_expected_by_request[request_key] = r.response_id

  # Pass 2: select responses to write and detect requests that just became
  # complete.
  responses_to_write = []
  client_messages_to_delete = []
  for r in responses:
    request_key = (r.client_id, r.flow_id, r.request_id)
    if request_key not in currently_available_requests:
      logging.info("Dropping response for unknown request %s/%s/%d",
                   r.client_id, r.flow_id, r.request_id)
      continue

    # Appended before the duplicate check below, so responses that are
    # already stored are passed to _WriteResponses again (presumably it
    # tolerates re-writes - verify).
    responses_to_write.append(r)
    current_responses = current_responses_by_request.setdefault(
        request_key, set())
    if r.response_id in current_responses:
      # We have this response already, nothing further to do.
      continue

    current_responses.add(r.response_id)
    # 0 when no status is known yet, so the equality below cannot hold.
    expected_responses = responses_expected_by_request.get(request_key, 0)
    if len(current_responses) == expected_responses:
      # This response was the one that was missing, time to set the
      # needs_processing flag.
      needs_processing_update.add(request_key)
      if r.request_id == next_request_by_flow[(r.client_id, r.flow_id)]:
        # The request that is now ready for processing was also the one the
        # flow was waiting for.
        req = rdf_flows.FlowProcessingRequest(
            client_id=r.client_id, flow_id=r.flow_id)
        flow_processing_requests.append(req)

      # Since this request is now complete, we can remove the corresponding
      # client messages if there are any.
      task_id = task_ids_by_request.get(request_key, None)
      if task_id is not None:
        client_messages_to_delete.append((r.client_id, task_id))

  # Persist everything collected above using the same cursor.
  now = rdfvalue.RDFDatetime.Now()
  now_str = mysql_utils.RDFDatetimeToMysqlString(now)
  if responses_to_write:
    self._WriteResponses(responses_to_write, now_str, cursor)
  self._UpdateRequests(needs_processing_update, needs_expected_update, cursor)

  if client_messages_to_delete:
    self._DeleteClientMessages(client_messages_to_delete, cursor)

  if flow_processing_requests:
    self._WriteFlowProcessingRequests(flow_processing_requests, cursor)
def WriteFlowResponses(self, responses):
  """Stores flow responses and flags requests that became complete.

  Responses for unknown flows or unknown requests are logged and skipped.
  Each stored response is a copy stamped with the current time.

  A FlowStatus sets its request's `nr_responses_expected` to the status's
  `response_id`. When a request's stored response count reaches that number:
  - the request is flagged `needs_processing`;
  - its client message is deleted if a task id was seen;
  - if the flow is waiting on that request, a FlowProcessingRequest is
    written.

  Args:
    responses: A list of flow response objects, possibly including
      FlowStatus objects.
  """
  statuses = set()
  touched = set()
  task_id_by_key = {}

  for resp in responses:
    flow_key = (resp.client_id, resp.flow_id)
    if flow_key not in self.flows:
      logging.error("Received response for unknown flow %s, %s.",
                    resp.client_id, resp.flow_id)
      continue
    if resp.request_id not in self.flow_requests.get(flow_key, {}):
      logging.error(
          "Received response for unknown request %s, %s, %d.",
          resp.client_id, resp.flow_id, resp.request_id)
      continue

    per_flow = self.flow_responses.setdefault(flow_key, {})
    stored = resp.Copy()
    stored.timestamp = rdfvalue.RDFDatetime.Now()
    per_flow.setdefault(resp.request_id, {})[resp.response_id] = stored

    if isinstance(resp, rdf_flow_objects.FlowStatus):
      statuses.add(resp)

    request_key = (resp.client_id, resp.flow_id, resp.request_id)
    touched.add(request_key)
    try:
      task_id_by_key[request_key] = resp.task_id
    except AttributeError:
      # Responses without a task_id attribute are simply not tracked.
      pass

  # A status tells its request how many responses to expect.
  for status in statuses:
    flow_requests = self.flow_requests[(status.client_id, status.flow_id)]
    flow_requests[status.request_id].nr_responses_expected = status.response_id

  to_process = []
  for client_id, flow_id, request_id in touched:
    flow_key = (client_id, flow_id)
    request = self.flow_requests[flow_key][request_id]
    # Skip requests without a known count or already flagged.
    if not request.nr_responses_expected or request.needs_processing:
      continue
    received = self.flow_responses.setdefault(flow_key, {}).get(
        request_id, {})
    if len(received) != request.nr_responses_expected:
      continue

    request.needs_processing = True
    task_id = task_id_by_key.get((client_id, flow_id, request_id))
    if task_id:
      self._DeleteClientMessage(client_id, task_id)

    if self.flows[flow_key].next_request_to_process == request_id:
      to_process.append(
          rdf_flows.FlowProcessingRequest(client_id=client_id, flow_id=flow_id))

  if to_process:
    self.WriteFlowProcessingRequests(to_process)