def OnStartup(self):
    """Report a request left in the transaction log, then announce startup.

    A non-empty transaction log is taken to mean that the client died while
    working on that request, so a CLIENT_KILLED status is sent back for it
    (with the nanny status attached when one is available). The log is
    cleared in every case and the startup info action is run afterwards.
    """
    interrupted = self.transaction_log.Get()
    if interrupted:
        crash_status = rdf_flows.GrrStatus(
            status=rdf_flows.GrrStatus.ReturnedStatus.CLIENT_KILLED,
            error_message="Client killed during transaction")

        # The nanny controller is optional and may have nothing to report.
        nanny_report = (
            self.nanny_controller.GetNannyStatus()
            if self.nanny_controller else None)
        if nanny_report:
            crash_status.nanny_status = nanny_report

        self.SendReply(
            crash_status,
            request_id=interrupted.request_id,
            response_id=1,
            session_id=interrupted.session_id,
            message_type=rdf_flows.GrrMessage.Type.STATUS)

    self.transaction_log.Clear()

    # Let the server know that this client has started.
    admin.SendStartupInfo(grr_worker=self).Run(None, ttl=1)
def _SendTerminationMessage(self, status=None):
    """Queue a STATUS response for the parent flow, if there is one.

    Args:
      status: Optional GrrStatus to send. A fresh one is created when omitted.
        Either way it is filled in with this flow's CPU usage, network bytes
        sent and session id before being queued.
    """
    if not self.runner_args.request_state.session_id:
        # Without a parent session there is nobody to notify.
        return

    if status is None:
        status = rdf_flows.GrrStatus()

    # Copy the accumulated resource usage onto the outgoing status.
    cpu_usage = self.context.client_resources.cpu_usage
    user_seconds = cpu_usage.user_cpu_time
    system_seconds = cpu_usage.system_cpu_time
    status.cpu_time_used.user_cpu_time = user_seconds
    status.cpu_time_used.system_cpu_time = system_seconds
    status.network_bytes_sent = self.context.network_bytes_sent
    status.child_session_id = self.session_id

    parent_request = self.runner_args.request_state
    parent_request.response_count += 1

    termination_msg = rdf_flows.GrrMessage(
        session_id=parent_request.session_id,
        request_id=parent_request.id,
        response_id=parent_request.response_count,
        auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED,
        type=rdf_flows.GrrMessage.Type.STATUS,
        payload=status)

    # Queue the response and notify the parent session about it.
    self.queue_manager.QueueResponse(termination_msg)
    self.QueueNotification(session_id=parent_request.session_id)
def GenerateStatusMessage(self, message, response_id, status=None):
    """Build the STATUS message that answers `message`.

    For the "StatFile" and "GetFileStat" actions the status also carries
    sample resource usage: the configured values when they are set,
    otherwise numbers derived from the running response counter.

    Args:
      message: The request message being answered.
      response_id: Response id to put on the status message.
      status: Optional returned-status value; OK is used when falsy.

    Returns:
      A GrrMessage of type STATUS whose payload is the GrrStatus.
    """
    payload = rdf_flows.GrrStatus(
        status=status or rdf_flows.GrrStatus.ReturnedStatus.OK)

    if message.name in ("StatFile", "GetFileStat"):
        cpu = payload.cpu_time_used
        cpu.user_cpu_time = (
            self.responses
            if self.user_cpu_time is None else self.user_cpu_time)
        cpu.system_cpu_time = (
            self.responses * 2
            if self.system_cpu_time is None else self.system_cpu_time)
        payload.network_bytes_sent = (
            self.responses * 3
            if self.network_bytes_sent is None else self.network_bytes_sent)

    return rdf_flows.GrrMessage(
        session_id=message.session_id,
        name=message.name,
        response_id=response_id,
        request_id=message.request_id,
        payload=payload,
        type=rdf_flows.GrrMessage.Type.STATUS)
def testCrashReport(self):
    """A CLIENT_KILLED status must be recorded as crash info for the client."""
    client_id = "C.1234567890123456"
    flow_id = "12345678"
    session_id = "%s/%s" % (client_id, flow_id)

    data_store.REL_DB.WriteClientMetadata(client_id, fleetspeak_enabled=False)
    self._FlowSetup(client_id, flow_id)

    # Make sure the event handler is present.
    self.assertTrue(administrative.ClientCrashHandler)

    killed = rdf_flows.GrrStatus(
        status=rdf_flows.GrrStatus.ReturnedStatus.CLIENT_KILLED)
    crash_message = rdf_flows.GrrMessage(
        source=client_id,
        request_id=1,
        response_id=1,
        session_id=session_id,
        payload=killed,
        auth_state="AUTHENTICATED",
        type=rdf_flows.GrrMessage.Type.STATUS)

    ReceiveMessages(client_id, [crash_message])

    stored_crash = data_store.REL_DB.ReadClientCrashInfo(client_id)
    self.assertTrue(stored_crash)
    self.assertEqual(stored_crash.session_id, session_id)
def testUnauthenticated(self):
    """An unauthenticated request must be answered with an error status.

    The client is expected to reply with a single GrrStatus message whose
    error text mentions both the RuntimeError and the missing authentication.
    """
    request = rdf_flows.GrrMessage(
        name="MockAction",
        session_id=self.session_id,
        auth_state=rdf_flows.GrrMessage.AuthorizationState.UNAUTHENTICATED,
        request_id=1,
        generate_task_id=True)
    self.context.HandleMessage(request)

    # Exactly one message, the error status, should come back.
    replies = self.context.Drain().job
    self.assertEqual(len(replies), 1)

    reply = replies[0]
    self.assertEqual(reply.session_id, self.session_id)
    self.assertEqual(reply.response_id, 1)

    status = rdf_flows.GrrStatus(reply.payload)
    self.assertIn("not Authenticated", status.error_message)
    self.assertIn("RuntimeError", status.error_message)
    self.assertNotEqual(status.status, rdf_flows.GrrStatus.ReturnedStatus.OK)
def SendToServer(self):
    """Queue ten status replies on the client worker as sample traffic."""
    worker = self.client_communicator.client_worker
    for response_id in range(10):
        worker.SendReply(
            rdf_flows.GrrStatus(),
            session_id=rdfvalue.SessionID("W:session"),
            response_id=response_id,
            request_id=1)
def GenerateStatusMessage(self, message, response_id=1, status=None):
    """Build the STATUS message that answers `message`.

    Args:
      message: The request message being answered; its session id, name,
        request id and task id are copied onto the reply.
      response_id: Response id to put on the status message.
      status: Optional returned-status value; OK is used when falsy.

    Returns:
      A GrrMessage of type STATUS carrying a GrrStatus payload.
    """
    returned_status = status or rdf_flows.GrrStatus.ReturnedStatus.OK
    payload = rdf_flows.GrrStatus(status=returned_status)
    return rdf_flows.GrrMessage(
        session_id=message.session_id,
        name=message.name,
        response_id=response_id,
        request_id=message.request_id,
        task_id=message.task_id,
        payload=payload,
        type=rdf_flows.GrrMessage.Type.STATUS)
def RunAction(self, action_cls, arg=None, grr_worker=None):
    """Run a client action and return what has been stored in self.results.

    Args:
      action_cls: The action class to instantiate via _GetActionInstance.
      arg: Argument handed to the action's Run(); an empty GrrMessage is
        used when omitted.
      grr_worker: Optional worker passed through to the action instance.

    Returns:
      The `self.results` list, which is reset before the action runs.
    """
    request = rdf_flows.GrrMessage() if arg is None else arg

    self.results = []
    instance = self._GetActionInstance(action_cls, grr_worker=grr_worker)

    # Start the action from a clean OK status.
    instance.status = rdf_flows.GrrStatus(
        status=rdf_flows.GrrStatus.ReturnedStatus.OK)
    instance.Run(request)

    return self.results
def SendOKStatus(self, response_id, session_id):
    """Send an authenticated OK status for request 1 of the given session.

    Args:
      response_id: Response id to put on the status message.
      session_id: Session the status message is addressed to.
    """
    envelope = rdf_flows.GrrMessage(
        request_id=1,
        response_id=response_id,
        session_id=session_id,
        type=rdf_flows.GrrMessage.Type.STATUS,
        auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED)
    envelope.payload = rdf_flows.GrrStatus(
        status=rdf_flows.GrrStatus.ReturnedStatus.OK)
    self.SendMessage(envelope)
def Run(self, unused_arg):
    """Acknowledge the kill request, wait briefly, then exit the process."""
    # Tell the server we are about to shut down before actually doing so.
    self.SendReply(
        rdf_flows.GrrStatus(status=rdf_flows.GrrStatus.ReturnedStatus.OK),
        message_type=rdf_flows.GrrMessage.Type.STATUS)

    # Give the http thread some time to send the reply.
    self.grr_worker.Sleep(10)

    logging.info("Dying on request.")
    os._exit(242)  # pylint: disable=protected-access
def testEqualTimestampNotifications(self):
    """Replies arriving in one batch must collapse into one notification."""
    frontend_server = frontend_lib.FrontEndServer(
        certificate=config.CONFIG["Frontend.certificate"],
        private_key=config.CONFIG["PrivateKeys.server_key"],
        message_expiry_time=100,
        threadpool_prefix="notification-test")

    # This schedules 10 requests.
    session_id = flow.StartFlow(
        client_id=self.client_id,
        flow_name="WorkerSendingTestFlow",
        token=self.token)

    # Pretend the client handled all 10 requests at once and returns every
    # reply (one data message plus one status per request) in a single poll.
    request_ids = range(1, 11)

    messages = []
    for request_id in request_ids:
        messages.append(
            rdf_flows.GrrMessage(
                request_id=request_id,
                response_id=1,
                session_id=session_id,
                payload=rdf_protodict.DataBlob(string="test%s" % request_id),
                auth_state="AUTHENTICATED",
                generate_task_id=True))

    status = rdf_flows.GrrStatus(status=rdf_flows.GrrStatus.ReturnedStatus.OK)
    statuses = []
    for request_id in request_ids:
        statuses.append(
            rdf_flows.GrrMessage(
                request_id=request_id,
                response_id=2,
                session_id=session_id,
                payload=status,
                type=rdf_flows.GrrMessage.Type.STATUS,
                auth_state="AUTHENTICATED",
                generate_task_id=True))

    frontend_server.ReceiveMessages(self.client_id, messages + statuses)

    with queue_manager.QueueManager(token=self.token) as q:
        by_priority = q.GetNotificationsByPriorityForAllShards(
            rdfvalue.RDFURN("aff4:/F"))
        medium = by_priority[
            rdf_flows.GrrNotification.Priority.MEDIUM_PRIORITY]
        ours = [n for n in medium if n.session_id == session_id]

        # There must not be more than one notification.
        self.assertEqual(len(ours), 1)
        only_notification = ours[0]
        self.assertEqual(only_notification.first_queued,
                         only_notification.timestamp)
        self.assertEqual(only_notification.last_status, 10)
def testAnyValueWithoutTypeCallback(self):
    """Values assigned to the dynamic field must survive a round trip."""
    samples = (
        rdfvalue.RDFString("test"),
        rdfvalue.RDFInteger(1234),
        rdfvalue.RDFBytes(b"abc"),
        rdf_flows.GrrStatus(status="WORKER_STUCK", error_message="stuck"),
    )

    test_pb = AnyValueWithoutTypeFunctionTest()
    for sample in samples:
        test_pb.dynamic = sample
        wire_bytes = test_pb.SerializeToString()
        restored = AnyValueWithoutTypeFunctionTest.FromSerializedString(
            wire_bytes)
        self.assertEqual(restored, test_pb)
def GenerateStatusMessage(self, message, response_id=1):
    """Build an OK STATUS message carrying the next resource-usage samples.

    One value is pulled from each of `self.user_cpu_usage`,
    `self.system_cpu_usage` and `self.network_usage` (in that order) per
    call, so these attributes must be iterators.

    Args:
      message: The request message being answered.
      response_id: Response id to put on the status message.

    Returns:
      A GrrMessage of type STATUS whose GrrStatus payload reports the CPU
      time and network bytes taken from the iterators.
    """
    # Use the builtin next() rather than the Python-2-only `.next()` method
    # so this works with iterators on Python 3 as well.
    cpu_time_used = rdf_client_stats.CpuSeconds(
        user_cpu_time=next(self.user_cpu_usage),
        system_cpu_time=next(self.system_cpu_usage))
    network_bytes_sent = next(self.network_usage)

    return rdf_flows.GrrMessage(
        session_id=message.session_id,
        name=message.name,
        response_id=response_id,
        request_id=message.request_id,
        payload=rdf_flows.GrrStatus(
            status=rdf_flows.GrrStatus.ReturnedStatus.OK,
            cpu_time_used=cpu_time_used,
            network_bytes_sent=network_bytes_sent),
        type=rdf_flows.GrrMessage.Type.STATUS)
def Error(self, backtrace, client_id=None, status_code=None):
    """Terminate this flow with an error status.

    Args:
      backtrace: Error text / traceback; when truthy it is stored on the
        status and on the flow context.
      client_id: Client the flow ran on; falls back to the runner args.
      status_code: Status to report; GENERIC_ERROR when None.
    """
    try:
        self.queue_manager.DestroyFlowStates(self.session_id)
    except queue_manager.MoreDataException:
        # Deliberately ignored: cleanup here is best effort.
        pass

    if not self.IsRunning():
        return

    reply = rdf_flows.GrrStatus()
    reply.status = (
        rdf_flows.GrrStatus.ReturnedStatus.GENERIC_ERROR
        if status_code is None else status_code)

    client_id = client_id or self.runner_args.client_id

    if not backtrace:
        logging.error("Error in flow %s (%s).", self.session_id, client_id)
    else:
        reply.error_message = backtrace
        logging.error("Error in flow %s (%s). Trace: %s", self.session_id,
                      client_id, backtrace)
        self.context.backtrace = backtrace

    self._SendTerminationMessage(reply)
    self.context.state = rdf_flow_runner.FlowContext.State.ERROR

    if self.ShouldSendNotifications():
        # The flow reference can only be built when the client is known.
        flow_ref = None
        if client_id:
            flow_ref = rdf_objects.FlowReference(
                client_id=client_id.Basename(),
                flow_id=self.session_id.Basename())

        failed_flow = rdf_objects.ObjectReference(
            reference_type=rdf_objects.ObjectReference.Type.FLOW,
            flow=flow_ref)
        notification_lib.Notify(
            self.token.username,
            rdf_objects.UserNotification.Type.TYPE_FLOW_RUN_FAILED,
            "Flow (%s) terminated due to error" % self.session_id,
            failed_flow)

    self.flow_obj.Flush()
def AsLegacyGrrMessage(self):
    """Convert this object into a legacy GrrMessage of type STATUS.

    Returns:
      A GrrMessage whose payload is a GrrStatus. Only the optional fields
      that are truthy on this object are copied onto the payload.
    """
    payload = rdf_flows.GrrStatus(status=inv_status_map[self.status])

    # Copy each optional field only when it actually holds a value.
    for field in ("error_message", "backtrace", "cpu_time_used",
                  "network_bytes_sent"):
        value = getattr(self, field)
        if value:
            setattr(payload, field, value)

    legacy_session_id = "%s/flows/%s" % (self.client_id, self.flow_id)
    return rdf_flows.GrrMessage(
        session_id=legacy_session_id,
        request_id=self.request_id,
        response_id=self.response_id,
        type="STATUS",
        timestamp=self.timestamp,
        payload=payload)
def __init__(self, grr_worker=None):
    """Set up the bookkeeping state of an action plugin.

    Args:
      grr_worker: The grr client worker object which may be used to e.g. send
        new actions on.
    """
    self.grr_worker = grr_worker
    self.response_id = INITIAL_RESPONSE_ID
    self.cpu_used = None
    self.nanny_controller = None

    # Every action starts out with an OK status.
    ok = rdf_flows.GrrStatus.ReturnedStatus.OK
    self.status = rdf_flows.GrrStatus(status=ok)

    # Garbage collection bookkeeping.
    self._last_gc_run = rdfvalue.RDFDatetime.Now()
    self._gc_frequency = config.CONFIG["Client.gc_frequency"]

    # CPU accounting: remember the process CPU times at construction and take
    # the limit from the cpu_limit field of an empty GrrMessage.
    self.proc = psutil.Process()
    self.cpu_start = self.proc.cpu_times()
    self.cpu_limit = rdf_flows.GrrMessage().cpu_limit
def testHandleError(self):
    """A request whose action raises must produce an error status reply."""
    request = rdf_flows.GrrMessage(
        name="RaiseAction",
        session_id=self.session_id,
        auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED,
        request_id=1,
        generate_task_id=True)
    self.context.HandleMessage(request)

    # The first drained message is the status describing the failure.
    first_reply = self.context.Drain().job[0]
    self.assertEqual(first_reply.session_id, self.session_id)
    self.assertEqual(first_reply.response_id, 1)

    status = rdf_flows.GrrStatus(first_reply.payload)
    self.assertIn("RuntimeError", status.error_message)
    self.assertNotEqual(status.status, rdf_flows.GrrStatus.ReturnedStatus.OK)
def CallState(self, next_state="", start_time=None):
    """This method is used to schedule a new state on a different worker.

    This is basically the same as CallFlow() except we are calling
    ourselves. The state will be invoked in a later time and receive all the
    messages we send.

    Args:
      next_state: The state in this flow to be invoked with the responses.
      start_time: Start the flow at this time. This delays notification for
        flow processing into the future. Note that the flow may still be
        processed earlier if there are client responses waiting.

    Raises:
      FlowRunnerError: if the next state is not valid.
    """
    # Check if the state is valid. The default of None makes an unknown state
    # name surface as the documented FlowRunnerError rather than as an
    # AttributeError raised by getattr itself.
    if not getattr(self.flow_obj, next_state, None):
        raise FlowRunnerError("Next state %s is invalid." % next_state)

    # Queue a request addressed to our own session for the next state.
    request_state = rdf_flow_runner.RequestState(
        id=self.GetNextOutboundId(),
        session_id=self.context.session_id,
        client_id=self.runner_args.client_id,
        next_state=next_state)

    self.QueueRequest(request_state, timestamp=start_time)

    # Send a fake reply: a bare status message answering that request.
    msg = rdf_flows.GrrMessage(
        session_id=self.session_id,
        request_id=request_state.id,
        response_id=1,
        auth_state=rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED,
        payload=rdf_flows.GrrStatus(),
        type=rdf_flows.GrrMessage.Type.STATUS)
    self.QueueResponse(msg, start_time)

    # Notify the worker about it.
    self.QueueNotification(session_id=self.session_id, timestamp=start_time)
def StatFile(self, args):
    """Mock of the StatFile action returning a canned StatEntry.

    Args:
      args: Request arguments; only `pathspec` is used.

    Returns:
      `[response, status]` normally, or just `[status]` on the call where
      the internal counter reaches `self.failrate` (the counter is then
      reset), simulating a client that does not have the file.
    """
    stat_entry = rdf_client.StatEntry(
        pathspec=args.pathspec,
        st_mode=33184,
        st_ino=1063090,
        st_dev=64512,
        st_nlink=1,
        st_uid=139592,
        st_gid=5000,
        st_size=len(self.data),
        st_atime=1336469177,
        st_mtime=1336129892,
        st_ctime=1336129892)

    self.responses += 1
    self.count += 1

    # Report sample resource usage: configured values win, otherwise the
    # numbers are derived from the response counter.
    status = rdf_flows.GrrStatus(
        status=rdf_flows.GrrStatus.ReturnedStatus.OK)
    status.cpu_time_used.user_cpu_time = (
        self.responses if self.user_cpu_time is None else self.user_cpu_time)
    status.cpu_time_used.system_cpu_time = (
        self.responses * 2
        if self.system_cpu_time is None else self.system_cpu_time)
    status.network_bytes_sent = (
        self.responses * 3
        if self.network_bytes_sent is None else self.network_bytes_sent)

    # Every "failrate" client does not have this file.
    if self.count != self.failrate:
        return [stat_entry, status]

    self.count = 0
    return [status]
def run(self):
    """Main thread for processing messages.

    Runs the startup handler, then handles messages taken from the input
    queue until a `None` message arrives. A failure while handling a single
    message is reported back as a GENERIC_ERROR status and the loop carries
    on. Whenever the loop is left, for any reason, the process sends itself
    SIGKILL.
    """
    self.OnStartup()

    try:
        while True:
            message = self._in_queue.get()

            # A message of None is our terminal message.
            if message is None:
                break

            try:
                self.HandleMessage(message)
                # Catch any errors and keep going here
            except Exception as e:  # pylint: disable=broad-except
                logging.warning("%s", e)
                # Report the failure for this request as an error status.
                self.SendReply(
                    rdf_flows.GrrStatus(
                        status=rdf_flows.GrrStatus.ReturnedStatus.GENERIC_ERROR,
                        error_message=utils.SmartUnicode(e)),
                    request_id=message.request_id,
                    response_id=1,
                    session_id=message.session_id,
                    task_id=message.task_id,
                    message_type=rdf_flows.GrrMessage.Type.STATUS)
                # Optionally drop into the debugger for the failed message.
                if flags.FLAGS.pdb_post_mortem:
                    pdb.post_mortem()

    except Exception as e:  # pylint: disable=broad-except
        logging.error("Exception outside of the processing loop: %r", e)
    finally:
        # There's no point in running the client if it's broken out of the
        # processing loop and it should be restarted shortly anyway.
        logging.fatal("The client has broken out of its processing loop.")

        # The binary (Python threading library, perhaps) has proven in tests
        # to be very persistent to termination calls, so we kill it with fire.
        os.kill(os.getpid(), signal.SIGKILL)
def SendResponse(self,
                 session_id,
                 data,
                 client_id=None,
                 well_known=False,
                 request_id=None):
    """Queue a response (plus a status for normal flows) and a notification.

    Args:
      session_id: Session the response is addressed to.
      data: Payload; anything that is not an RDFValue is wrapped in a
        DataBlob as its string.
      client_id: Source to put on the queued messages.
      well_known: If True, request id 0 and response id 12345 are used and
        no status message is sent.
      request_id: Request id for normal flows; defaults to 1.

    Returns:
      The queue manager's frozen timestamp.
    """
    if not isinstance(data, rdfvalue.RDFValue):
        data = rdf_protodict.DataBlob(string=data)

    if well_known:
        request_id = 0
        response_id = 12345
    else:
        request_id = request_id or 1
        response_id = 1

    with queue_manager.QueueManager(token=self.token) as flow_manager:
        data_msg = rdf_flows.GrrMessage(
            source=client_id,
            session_id=session_id,
            payload=data,
            request_id=request_id,
            auth_state="AUTHENTICATED",
            response_id=response_id)
        flow_manager.QueueResponse(data_msg)

        if not well_known:
            # For normal flows we have to send a status as well.
            status_msg = rdf_flows.GrrMessage(
                source=client_id,
                session_id=session_id,
                payload=rdf_flows.GrrStatus(
                    status=rdf_flows.GrrStatus.ReturnedStatus.OK),
                request_id=request_id,
                response_id=response_id + 1,
                auth_state="AUTHENTICATED",
                type=rdf_flows.GrrMessage.Type.STATUS)
            flow_manager.QueueResponse(status_msg)

        flow_manager.QueueNotification(
            session_id=session_id, last_status=request_id)
        timestamp = flow_manager.frozen_timestamp

    return timestamp
def ReceiveMessagesRelationalFlows(self, client_id, messages):
    """Receives and processes messages for flows stored in the relational db.

    Messages are grouped by session id. Unauthenticated messages are dropped
    unless addressed to the session that allows them. The remainder is either
    turned into message handler requests, dropped (legacy well known
    sessions) or written out as flow responses. CLIENT_KILLED statuses among
    the flow responses additionally publish a "ClientCrash" event.

    Args:
      client_id: The client which sent the messages.
      messages: A list of GrrMessage RDFValues.
    """
    started = time.time()
    flow_msgs = []
    handler_requests = []
    dropped = 0

    grouped = collection.Group(messages, operator.attrgetter("session_id"))
    for session_id, msgs in iteritems(grouped):
        # Remove and handle messages to WellKnownFlows
        remaining = self.HandleWellKnownFlows(msgs)

        for msg in remaining:
            accepted = (
                msg.auth_state == msg.AuthorizationState.AUTHENTICATED or
                msg.session_id == self.unauth_allowed_session_id)
            if not accepted:
                dropped += 1
                continue

            if session_id in queue_manager.session_id_map:
                handler_request = rdf_objects.MessageHandlerRequest(
                    client_id=msg.source.Basename(),
                    handler_name=queue_manager.session_id_map[session_id],
                    request_id=msg.response_id,
                    request=msg.payload)
                handler_requests.append(handler_request)
            elif session_id in self.legacy_well_known_session_ids:
                logging.debug(
                    "Dropping message for legacy well known session id %s",
                    session_id)
            else:
                flow_msgs.append(msg)

    if dropped:
        logging.info("Dropped %d unauthenticated messages for %s", dropped,
                     client_id)

    if flow_msgs:
        data_store.REL_DB.WriteFlowResponses([
            rdf_flow_objects.FlowResponseForLegacyResponse(msg)
            for msg in flow_msgs
        ])

        for msg in flow_msgs:
            if msg.type != rdf_flows.GrrMessage.Type.STATUS:
                continue

            stat = rdf_flows.GrrStatus(msg.payload)
            if stat.status != rdf_flows.GrrStatus.ReturnedStatus.CLIENT_KILLED:
                continue

            # A client crashed while performing an action, fire an event.
            crash_details = rdf_client.ClientCrash(
                client_id=client_id,
                session_id=msg.session_id,
                backtrace=stat.backtrace,
                crash_message=stat.error_message,
                nanny_status=stat.nanny_status,
                timestamp=rdfvalue.RDFDatetime.Now())
            events.Events.PublishEvent(
                "ClientCrash", crash_details, token=self.token)

    if handler_requests:
        data_store.REL_DB.WriteMessageHandlerRequests(handler_requests)

    logging.debug("Received %s messages from %s in %s sec", len(messages),
                  client_id, time.time() - started)
def ReceiveMessages(self, client_id: str,
                    messages: Iterable[rdf_flows.GrrMessage]):
    """Receives and processes the messages.

    For each message we update the request object, and place the response in
    that request's queue. If the request is complete, we send a message to
    the worker.

    Messages are grouped by session id. Unauthenticated messages are dropped
    unless addressed to the session that allows them. Messages for sessions
    with a registered message handler become handler requests (processed
    directly when the handler is in `_SHORTCUT_HANDLERS`, written to the
    database otherwise); messages for legacy well known sessions are
    dropped; everything else is written as flow responses. CLIENT_KILLED
    statuses among those additionally publish a "ClientCrash" event.

    Args:
      client_id: The client which sent the messages.
      messages: A list of GrrMessage RDFValues.
    """
    now = time.time()
    unprocessed_msgs = []
    worker_message_handler_requests = []
    frontend_message_handler_requests = []
    dropped_count = 0

    msgs_by_session_id = collection.Group(messages, lambda m: m.session_id)
    for session_id, msgs in msgs_by_session_id.items():
        # The string form is the same for every message of the group, so it
        # is computed once per group instead of once per message.
        session_id_str = str(session_id)

        for msg in msgs:
            if (msg.auth_state != msg.AuthorizationState.AUTHENTICATED and
                    msg.session_id != self.unauth_allowed_session_id):
                dropped_count += 1
                continue

            if session_id_str in message_handlers.session_id_map:
                # Look the handler up with the same key that was just used
                # for the membership test, so the lookup cannot miss.
                request = rdf_objects.MessageHandlerRequest(
                    client_id=msg.source.Basename(),
                    handler_name=message_handlers.session_id_map[
                        session_id_str],
                    request_id=msg.response_id or random.UInt32(),
                    request=msg.payload)
                if request.handler_name in self._SHORTCUT_HANDLERS:
                    frontend_message_handler_requests.append(request)
                else:
                    worker_message_handler_requests.append(request)
            elif session_id_str in self.legacy_well_known_session_ids:
                logging.debug(
                    "Dropping message for legacy well known session id %s",
                    session_id)
            else:
                unprocessed_msgs.append(msg)

    if dropped_count:
        logging.info("Dropped %d unauthenticated messages for %s",
                     dropped_count, client_id)

    if unprocessed_msgs:
        flow_responses = []
        for message in unprocessed_msgs:
            # A message that cannot be converted is logged and skipped so
            # that it does not prevent the others from being written.
            try:
                flow_responses.append(
                    rdf_flow_objects.FlowResponseForLegacyResponse(message))
            except ValueError as e:
                logging.warning(
                    "Failed to parse legacy FlowResponse:\n%s\n%s", e,
                    message)

        data_store.REL_DB.WriteFlowResponses(flow_responses)

        for msg in unprocessed_msgs:
            if msg.type == rdf_flows.GrrMessage.Type.STATUS:
                stat = rdf_flows.GrrStatus(msg.payload)
                if (stat.status ==
                        rdf_flows.GrrStatus.ReturnedStatus.CLIENT_KILLED):
                    # A client crashed while performing an action, fire an
                    # event.
                    crash_details = rdf_client.ClientCrash(
                        client_id=client_id,
                        session_id=msg.session_id,
                        backtrace=stat.backtrace,
                        crash_message=stat.error_message,
                        nanny_status=stat.nanny_status,
                        timestamp=rdfvalue.RDFDatetime.Now())
                    events.Events.PublishEvent(
                        "ClientCrash", crash_details, token=self.token)

    if worker_message_handler_requests:
        data_store.REL_DB.WriteMessageHandlerRequests(
            worker_message_handler_requests)

    if frontend_message_handler_requests:
        worker_lib.ProcessMessageHandlerRequests(
            frontend_message_handler_requests)

    logging.debug("Received %s messages from %s in %s sec", len(messages),
                  client_id, time.time() - now)
def FromLegacyResponses(cls, request=None, responses=None):
    """Creates a Responses object from old style flow request and responses.

    Args:
      request: The flow request the responses belong to. Its `data` is
        copied into `request_data` when the request is truthy.
      responses: List of response messages. It is sorted in place by
        `response_id`.

    Returns:
      A new instance of `cls`. When `responses` is empty the instance is
      returned right after `request`, `request_data` and `iterator` are set.

    Raises:
      RuntimeError: If the request names an unknown client action, the
        action declares no `out_rdfvalues`, or a message has a missing or
        unexpected `args_rdf_name`.
      ValueError: If more than one iterator message is received, or no
        status message was found among the responses.
    """
    res = cls()
    res.request = request
    if request:
        res.request_data = rdf_protodict.Dict(request.data)
    # Unauthenticated messages, kept only for the error report further down.
    dropped_responses = []
    # The iterator that was returned as part of these responses. This should
    # be passed back to actions that expect an iterator.
    res.iterator = None

    if not responses:
        return res

    # This may not be needed if we can assume that responses are
    # returned in lexical order from the data_store.
    responses.sort(key=operator.attrgetter("response_id"))

    # NOTE(review): `request` defaults to None but is dereferenced here
    # whenever `responses` is non-empty - confirm callers always pass a
    # request together with responses.
    if request.HasField("request"):
        client_action_name = request.request.name
        action_registry = server_stubs.ClientActionStub.classes
        if client_action_name not in action_registry:
            raise RuntimeError(
                "Got unknown client action: %s." % client_action_name)
        expected_response_classes = action_registry[
            client_action_name].out_rdfvalues

    old_response_id = None

    # Filter the responses by authorized states
    for msg in responses:
        # Check if the message is authenticated correctly.
        if msg.auth_state != msg.AuthorizationState.AUTHENTICATED:
            logging.warning(
                "%s: Messages must be authenticated (Auth state %s)",
                msg.session_id, msg.auth_state)
            dropped_responses.append(msg)
            # Skip this message - it is invalid
            continue

        # Handle retransmissions: the list is sorted, so a repeated response
        # id directly follows the message it duplicates.
        if msg.response_id == old_response_id:
            continue

        old_response_id = msg.response_id

        # Check for iterators
        if msg.type == msg.Type.ITERATOR:
            if res.iterator:
                raise ValueError(
                    "Received multiple iterator messages at once.")
            res.iterator = rdf_client_action.Iterator(msg.payload)
            continue

        # Look for a status message
        if msg.type == msg.Type.STATUS:
            # Our status is set to the first status message that we see in
            # the responses. We ignore all other messages after that.
            res.status = rdf_flows.GrrStatus(msg.payload)

            # Check this to see if the call succeeded
            res.success = res.status.status == res.status.ReturnedStatus.OK

            # Ignore all other messages
            break

        if msg.type == msg.Type.MESSAGE:
            if request.HasField("request"):
                # Let's do some verification for requests that came from
                # clients.
                if not expected_response_classes:
                    raise RuntimeError(
                        "Client action %s does not specify out_rdfvalue." %
                        client_action_name)
                else:
                    args_rdf_name = msg.args_rdf_name
                    if not args_rdf_name:
                        raise RuntimeError(
                            "Deprecated message format received: "
                            "args_rdf_name is None.")
                    elif args_rdf_name not in [
                            x.__name__ for x in expected_response_classes
                    ]:
                        raise RuntimeError(
                            "Response type was %s but expected %s for %s." %
                            (args_rdf_name, expected_response_classes,
                             client_action_name))
        # Use this message
        res.responses.append(msg.payload)

    if res.status is None:
        # This is a special case of de-synchronized messages.
        if dropped_responses:
            logging.error(
                "De-synchronized messages detected:\n %s", "\n".join(
                    [utils.SmartUnicode(x) for x in dropped_responses]))

        res.LogFlowState(responses)

        raise ValueError("No valid Status message.")

    return res
def testEnums(self):
    """The status enum of a default GrrStatus must render as "OK"."""
    default_status = rdf_flows.GrrStatus().status
    self.assertEqual(str(default_status), "OK")
def Next(self):
    """Grab tasks for us from the server's queue.

    Each request is passed to the client mock, the responses it produces are
    wrapped into GrrMessages where needed and pushed onto the flow state
    queue, and a notification is queued for the session. When
    `status_message_enforced` is set, a status message is appended for every
    request (a GENERIC_ERROR one if the mock raised).

    Returns:
      The number of request tasks that were processed.

    Raises:
      RuntimeError: If no status was produced although
        `status_message_enforced` is set.
    """
    with queue_manager.QueueManager(token=self.token) as manager:
        request_tasks = manager.QueryAndOwn(
            self.client_id.Queue(), limit=1, lease_seconds=10000)

        # Locally queued mock tasks are processed along with the leased ones.
        request_tasks.extend(self._mock_task_queue)
        self._mock_task_queue[:] = []  # Clear the referenced list.

        for message in request_tasks:
            status = None
            response_id = 1

            # Collect all responses for this message from the client mock
            try:
                if hasattr(self.client_mock, "HandleMessage"):
                    responses = self.client_mock.HandleMessage(message)
                else:
                    # Otherwise dispatch to the mock method named after the
                    # requested action.
                    self.client_mock.message = message
                    responses = getattr(self.client_mock,
                                        message.name)(message.payload)

                if not responses:
                    responses = []

                logging.info(
                    "Called client action %s generating %s responses",
                    message.name,
                    len(responses) + 1)

                if self.status_message_enforced:
                    status = rdf_flows.GrrStatus()
            except Exception as e:  # pylint: disable=broad-except
                logging.exception("Error %s occurred in client", e)

                # Error occurred.
                responses = []
                if self.status_message_enforced:
                    error_message = str(e)
                    status = rdf_flows.GrrStatus(
                        status=rdf_flows.GrrStatus.ReturnedStatus.
                        GENERIC_ERROR)
                    # Invalid action mock is usually expected.
                    if error_message != "Invalid Action Mock.":
                        status.backtrace = traceback.format_exc()
                        status.error_message = error_message

            # Now insert those on the flow state queue
            for response in responses:
                if isinstance(response, rdf_flows.GrrStatus):
                    msg_type = rdf_flows.GrrMessage.Type.STATUS
                    self.AddResourceUsage(response)
                    response = rdf_flows.GrrMessage(
                        session_id=message.session_id,
                        name=message.name,
                        response_id=response_id,
                        request_id=message.request_id,
                        payload=response,
                        type=msg_type)
                elif isinstance(response, rdf_client.Iterator):
                    msg_type = rdf_flows.GrrMessage.Type.ITERATOR
                    response = rdf_flows.GrrMessage(
                        session_id=message.session_id,
                        name=message.name,
                        response_id=response_id,
                        request_id=message.request_id,
                        payload=response,
                        type=msg_type)
                elif not isinstance(response, rdf_flows.GrrMessage):
                    msg_type = rdf_flows.GrrMessage.Type.MESSAGE
                    response = rdf_flows.GrrMessage(
                        session_id=message.session_id,
                        name=message.name,
                        response_id=response_id,
                        request_id=message.request_id,
                        payload=response,
                        type=msg_type)

                # Next expected response. A response that already was a
                # GrrMessage keeps its own id, so continue counting from it.
                response_id = response.response_id + 1
                self.PushToStateQueue(manager, response)

            # Status may only be None if the client reported itself as
            # crashed.
            if status is not None:
                self.AddResourceUsage(status)
                self.PushToStateQueue(
                    manager,
                    message,
                    response_id=response_id,
                    payload=status,
                    type=rdf_flows.GrrMessage.Type.STATUS)
            else:
                # Status may be None only if status_message_enforced is
                # False.
                if self.status_message_enforced:
                    raise RuntimeError(
                        "status message can only be None when "
                        "status_message_enforced is False")

            # Additionally schedule a task for the worker
            manager.QueueNotification(
                session_id=message.session_id, priority=message.priority)

        return len(request_tasks)
def testNoValidStatusRaceIsResolved(self):
    """Checks that a half-written request does not break or stall the flow."""
    # This tests for the regression of a long standing race condition we saw
    # where notifications would trigger the reading of another request that
    # arrives later but wasn't completely written to the database yet.
    # Timestamp based notification handling should eliminate this bug.

    # We need a random flow object for this test.
    session_id = flow.StartAFF4Flow(
        client_id=self.client_id,
        flow_name="WorkerSendingTestFlow",
        token=self.token)
    worker_obj = self._TestWorker()
    # Remove any notification already queued for the new flow.
    manager = queue_manager.QueueManager(token=self.token)
    manager.DeleteNotification(session_id)
    manager.Flush()

    # We have a first request that is complete (request_id 1, response_id 1).
    self.SendResponse(session_id, "Response 1")

    # However, we also have request #2 already coming in. The race is that
    # the queue manager might write the status notification to
    # session_id/state as "status:00000002" but not the status response
    # itself yet under session_id/state/request:00000002
    request_id = 2
    response_id = 1
    flow_manager = queue_manager.QueueManager(token=self.token)
    flow_manager.FreezeTimestamp()

    flow_manager.QueueResponse(
        rdf_flows.GrrMessage(
            source=self.client_id,
            session_id=session_id,
            payload=rdf_protodict.DataBlob(string="Response 2"),
            request_id=request_id,
            auth_state="AUTHENTICATED",
            response_id=response_id))

    status = rdf_flows.GrrMessage(
        source=self.client_id,
        session_id=session_id,
        payload=rdf_flows.GrrStatus(
            status=rdf_flows.GrrStatus.ReturnedStatus.OK),
        request_id=request_id,
        response_id=response_id + 1,
        auth_state="AUTHENTICATED",
        type=rdf_flows.GrrMessage.Type.STATUS)

    # Now we write half the status information.
    data_store.DB.StoreRequestsAndResponses(new_responses=[(status, None)])

    # We make the race even a bit harder by saying the new notification gets
    # written right before the old one gets deleted. If we are not careful
    # here, we delete the new notification as well and the flow becomes stuck.

    # pylint: disable=invalid-name
    def WriteNotification(self, arg_session_id, start=None, end=None):
        # Stand-in for QueueManager.DeleteNotification: queue the new
        # notification for our session first, then call the original method.
        if arg_session_id == session_id:
            flow_manager.QueueNotification(session_id=arg_session_id)
            flow_manager.Flush()

        self.DeleteNotification.old_target(
            self, arg_session_id, start=start, end=end)

    # pylint: enable=invalid-name

    with utils.Stubber(queue_manager.QueueManager, "DeleteNotification",
                       WriteNotification):
        # This should process request 1 but not touch request 2.
        worker_obj.RunOnce()
        worker_obj.thread_pool.Join()

    # The flow must not have ended up in an error state.
    flow_obj = aff4.FACTORY.Open(session_id, token=self.token)
    self.assertFalse(flow_obj.context.backtrace)
    self.assertNotEqual(flow_obj.context.state,
                        rdf_flow_runner.FlowContext.State.ERROR)

    request_data = data_store.DB.ReadResponsesForRequestId(session_id, 2)
    request_data.sort(key=lambda msg: msg.response_id)
    self.assertLen(request_data, 2)

    # Make sure the status and the original request are still there.
    self.assertEqual(request_data[0].args_rdf_name, "DataBlob")
    self.assertEqual(request_data[1].args_rdf_name, "GrrStatus")

    # But there is nothing for request 1.
    request_data = data_store.DB.ReadResponsesForRequestId(session_id, 1)
    self.assertEqual(request_data, [])

    # The notification for request 2 should have survived.
    with queue_manager.QueueManager(token=self.token) as manager:
        notifications = manager.GetNotifications(queues.FLOWS)
        self.assertLen(notifications, 1)
        notification = notifications[0]
        self.assertEqual(notification.session_id, session_id)
        self.assertEqual(notification.timestamp,
                         flow_manager.frozen_timestamp)

    self.assertEqual(RESULTS, ["Response 1"])

    # The last missing piece of request 2 is the actual status message.
    flow_manager.QueueResponse(status)
    flow_manager.Flush()

    # Now make sure request 2 runs as expected.
    worker_obj.RunOnce()
    worker_obj.thread_pool.Join()

    self.assertEqual(RESULTS, ["Response 1", "Response 2"])
def ReceiveMessages(self, client_id, messages):
    """Queues the messages received from a client as flow responses.

    Messages are grouped by session id. Within each group, messages for
    well known flows are handled first, unauthenticated leftovers are
    dropped (unless addressed to `self.unauth_allowed_session_id`), and the
    remainder is queued as responses. A notification is queued for messages
    with request id 0 and for every status message.

    Args:
      client_id: The client which sent the messages.
      messages: A list of GrrMessage RDFValues.

    Returns:
      The result of `ReceiveMessagesRelationalFlows` when the relational DB
      is enabled, otherwise None.
    """
    if data_store.RelationalDBEnabled():
        return self.ReceiveMessagesRelationalFlows(client_id, messages)

    now = time.time()
    with queue_manager.QueueManager(token=self.token) as manager:
        by_session = collection.Group(messages,
                                      operator.attrgetter("session_id"))
        for session_id, msgs in iteritems(by_session):

            # Remove and handle messages to WellKnownFlows
            leftover_msgs = self.HandleWellKnownFlows(msgs)

            accepted = [
                msg for msg in leftover_msgs
                if (msg.auth_state == msg.AuthorizationState.AUTHENTICATED or
                    msg.session_id == self.unauth_allowed_session_id)
            ]

            if len(accepted) < len(leftover_msgs):
                logging.info("Dropped %d unauthenticated messages for %s",
                             len(leftover_msgs) - len(accepted), client_id)

            if not accepted:
                continue

            for msg in accepted:
                manager.QueueResponse(msg)

            for msg in accepted:
                if msg.request_id == 0:
                    # Messages for well known flows should notify even
                    # though they don't have a status. They are all the
                    # same, so one notification is enough.
                    manager.QueueNotification(session_id=msg.session_id)
                    break

                if msg.type != rdf_flows.GrrMessage.Type.STATUS:
                    continue

                # A status message from the client means the client has
                # finished processing this request, so it can be de-queued
                # from the client queue. msg.task_id will raise if the task
                # id is not set (message originated at the client, there
                # was no request on the server), hence the .HasTaskID()
                # check first.
                if msg.HasTaskID():
                    manager.DeQueueClientRequest(msg)

                manager.QueueNotification(
                    session_id=msg.session_id, last_status=msg.request_id)

                stat = rdf_flows.GrrStatus(msg.payload)
                killed = rdf_flows.GrrStatus.ReturnedStatus.CLIENT_KILLED
                if stat.status == killed:
                    # A client crashed while performing an action, fire an
                    # event.
                    crash_details = rdf_client.ClientCrash(
                        client_id=client_id,
                        session_id=session_id,
                        backtrace=stat.backtrace,
                        crash_message=stat.error_message,
                        nanny_status=stat.nanny_status,
                        timestamp=rdfvalue.RDFDatetime.Now())
                    events.Events.PublishEvent(
                        "ClientCrash", crash_details, token=self.token)

    logging.debug("Received %s messages from %s in %s sec", len(messages),
                  client_id, time.time() - now)
def testEnums(self):
    """Check that enums are wrapped in a descriptor class."""
    # A freshly constructed status must report the OK enum value.
    default_status = rdf_flows.GrrStatus()
    actual = default_status.status
    expected = rdf_flows.GrrStatus.ReturnedStatus.OK
    self.assertEqual(actual, expected)
def __init__(self, request=None, responses=None):
    """Builds the response set for a request from its raw messages.

    Responses are processed in `response_id` order. Unauthenticated
    messages are dropped, an iterator message is stored in `self.iterator`,
    and processing stops at the first status message, which sets
    `self.status` and `self.success`. All other messages are kept as the
    usable responses.

    Args:
      request: Optional request object. When truthy, its `data` is wrapped
        in an `rdf_protodict.Dict` and exposed as `self.request_data`;
        otherwise `self.request_data` is None.
      responses: Optional iterable of response messages. It is not
        modified; a sorted copy is used internally. When empty or None no
        status is required and the object is left in its default state.

    Raises:
      ValueError: If more than one iterator message is present, or if
        responses were given but none of the authenticated ones is a
        status message.
    """
    self.status = None  # A GrrStatus rdfvalue object.
    self.success = True
    self.request = request
    # Always define request_data so that reading it on an instance built
    # without a request yields None instead of raising AttributeError.
    self.request_data = None
    if request:
        self.request_data = rdf_protodict.Dict(request.data)
    self._responses = []
    dropped_responses = []

    # This is the raw message accessible while going through the iterator
    self.message = None

    # The iterator that was returned as part of these responses. This
    # should be passed back to actions that expect an iterator.
    self.iterator = None

    if not responses:
        return

    # Sort a copy rather than the caller's list: constructing this object
    # should not reorder data owned by the caller. The sort may not be
    # needed if responses can be assumed to come back from the data_store
    # in lexical order.
    ordered = sorted(responses, key=operator.attrgetter("response_id"))

    # Filter the responses by authorized states
    for msg in ordered:
        # Check if the message is authenticated correctly.
        if msg.auth_state != msg.AuthorizationState.AUTHENTICATED:
            logging.warning(
                "%s: Messages must be authenticated (Auth state %s)",
                msg.session_id, msg.auth_state)
            dropped_responses.append(msg)
            # Skip this message - it is invalid
            continue

        # Check for iterators
        if msg.type == msg.Type.ITERATOR:
            if self.iterator:
                raise ValueError(
                    "Received multiple iterator messages at once.")
            self.iterator = rdf_client.Iterator(msg.payload)
            continue

        # Look for a status message
        if msg.type == msg.Type.STATUS:
            # Our status is set to the first status message that we see in
            # the responses. We ignore all other messages after that.
            self.status = rdf_flows.GrrStatus(msg.payload)

            # Check this to see if the call succeeded
            self.success = (
                self.status.status == self.status.ReturnedStatus.OK)

            # Ignore all other messages
            break

        # Use this message
        self._responses.append(msg)

    if self.status is None:
        # This is a special case of de-synchronized messages.
        if dropped_responses:
            logging.error(
                "De-synchronized messages detected:\n %s",
                "\n".join(
                    [utils.SmartUnicode(x) for x in dropped_responses]))

        self._LogFlowState(ordered)

        raise ValueError("No valid Status message.")