def testHandleClientMessageRetransmission(self):
    """Check that requests get retransmitted but only if there is no status.

    Two scenarios are exercised against a single fake client:

    1. Nothing ever answers the request. Draining the client queue once per
       lease expiry must hand out the message on every round except the
       last one.
    2. A response and status are scheduled before the third round. From that
       round on nothing may be handed out and the client queue must be
       empty.
    """
    # Make a new fake client.
    client_id = self.SetupClients(1)[0]

    start_time = 1000
    default_ttl = rdfvalue.GrrMessage().task_ttl

    def PollTime(round_number):
        # Every round happens just after the previous lease has expired.
        return start_time + round_number * (self.message_expiry_time + 1)

    # Scenario 1: the standard behavior.
    drained_per_round = []
    with test_lib.FakeTime(start_time):
        flow.GRRFlow.StartFlow(
            client_id=client_id,
            flow_name="SendingFlow",
            message_count=1,
            token=self.token)

    for round_number in range(default_ttl):
        with test_lib.FakeTime(PollTime(round_number)):
            drained = self.server.DrainTaskSchedulerQueueForClient(
                client_id, 100000, rdfvalue.MessageList())
            drained_per_round.append(drained)

    # Should return a client message (ttl-1) times and nothing afterwards.
    self.assertEqual(
        [bool(drained) for drained in drained_per_round],
        [True] * (default_ttl - 1) + [False])

    # Scenario 2: we simulate that the workers are overloaded - the client
    # messages arrive but do not get processed in time.
    if default_ttl <= 3:
        self.fail("TTL too low for this test.")

    drained_per_round = []
    with test_lib.FakeTime(start_time):
        flow_id = flow.GRRFlow.StartFlow(
            client_id=client_id,
            flow_name="SendingFlow",
            message_count=1,
            token=self.token)

    for round_number in range(default_ttl):
        if round_number == 2:
            self._ScheduleResponseAndStatus(client_id, flow_id)

        with test_lib.FakeTime(PollTime(round_number)):
            drained = self.server.DrainTaskSchedulerQueueForClient(
                client_id, 100000, rdfvalue.MessageList())
            drained_per_round.append(drained)

            if drained:
                continue

            # Even if the request has not been leased ttl times yet,
            # it should be dequeued by now.
            manager = queue_manager.QueueManager(token=self.token)
            leftover = manager.Query(
                queue=rdfvalue.ClientURN(client_id).Queue(), limit=1000)
            self.assertEqual(len(leftover), 0)

    # Should return a client message twice and nothing afterwards.
    self.assertEqual(
        [bool(drained) for drained in drained_per_round],
        [True] * 2 + [False] * (default_ttl - 2))
def DecompressMessageList(self, signed_message_list):
    """Unpack the payload of a SignedMessageList into a MessageList.

    Args:
        signed_message_list: A SignedMessageList rdfvalue. Its `compression`
            field selects how `message_list` is interpreted.

    Returns:
        A MessageList rdfvalue parsed from the (decompressed) payload.

    Raises:
        DecodingError: If the compression scheme is not supported, if zlib
            decompression fails, or if the payload cannot be parsed.
    """
    scheme = signed_message_list.compression
    schemes = rdfvalue.SignedMessageList.CompressionType

    if scheme == schemes.UNCOMPRESSED:
        # Payload is already the serialized message list.
        payload = signed_message_list.message_list
    elif scheme == schemes.ZCOMPRESSION:
        try:
            payload = zlib.decompress(signed_message_list.message_list)
        except zlib.error as e:
            raise DecodingError("Failed to decompress: %s" % e)
    else:
        raise DecodingError("Compression scheme not supported")

    try:
        return rdfvalue.MessageList(payload)
    except rdfvalue.DecodeError:
        raise DecodingError("RDFValue parsing failed.")
def ClientServerCommunicate(self, timestamp=None):
    """Tests the end to end encrypted communicators.

    Encodes ten messages with the client communicator, decrypts the result
    with the server communicator, and checks the source, the timestamp and
    every session id. The serialized payload is kept in `self.cipher_text`.

    Args:
        timestamp: Optional timestamp forwarded to EncodeMessages.

    Returns:
        The messages as decoded by the server communicator.
    """
    message_count = 10

    def ExpectedSessionId(flow_number):
        # Session id used for (and expected back from) message `flow_number`.
        return rdfvalue.SessionID(
            base="aff4:/flows", queue=queues.FLOWS, flow_name=flow_number)

    outgoing = rdfvalue.MessageList()
    for flow_number in range(1, message_count + 1):
        outgoing.job.Append(
            session_id=ExpectedSessionId(flow_number),
            name="OMG it's a string")

    encoded = rdfvalue.ClientCommunication()
    # EncodeMessages returns the timestamp that was actually used.
    timestamp = self.client_communicator.EncodeMessages(
        outgoing, encoded, timestamp=timestamp)
    self.cipher_text = encoded.SerializeToString()

    decrypted = self.server_communicator.DecryptMessage(self.cipher_text)
    decoded_messages, source, client_timestamp = decrypted

    self.assertEqual(source, self.client_communicator.common_name)
    self.assertEqual(client_timestamp, timestamp)
    self.assertEqual(len(decoded_messages), 10)

    for index in range(message_count):
        self.assertEqual(
            decoded_messages[index].session_id,
            ExpectedSessionId(index + 1))

    return decoded_messages
def Drain(self, max_size=1024):
    """Collect queued outbound messages into a MessageList.

    This is used to get the messages going _TO_ the server when the client
    connects.

    Args:
        max_size: The size (in bytes) of the returned protobuf will be at
            most one message length over this size. The limit is checked
            after each message is added.

    Returns:
        A MessageList protobuf.
    """
    outgoing = rdfvalue.MessageList()
    total_size = 0

    for queued_message in self._out_queue.Get():
        outgoing.job.Append(queued_message)
        stats.STATS.IncrementCounter("grr_client_sent_messages")

        total_size += len(queued_message)
        # Stop once the budget is exceeded; the message that crossed the
        # limit is still part of the result.
        if total_size > max_size:
            break

    return outgoing
def Drain(self, max_size=1024):
    """Pop queued outbound messages into a MessageList.

    This is used to get the messages going _TO_ the server when the client
    connects. Entries are taken in ascending order of their first element
    and removed from `self._out_queue` as they are drained;
    `self._out_queue_size` is reduced accordingly.

    Args:
        max_size: The size of the returned protobuf will be at most one
            message length over this size. The limit is checked before each
            pop, using the summed `len(message.args)`.

    Returns:
        A MessageList protobuf.
    """
    outgoing = rdfvalue.MessageList()
    drained_size = 0

    # Sort ascending and then flip the list, so that the next entry to send
    # sits at the tail: front pops are quadratic, tail pops are cheap.
    self._out_queue.sort(key=lambda entry: entry[0])
    self._out_queue.reverse()

    while True:
        # Rely on the truthiness of the list rather than on its length.
        if not self._out_queue or drained_size >= max_size:
            break

        queued_message = self._out_queue.pop()[1]
        outgoing.job.Append(queued_message)
        stats.STATS.IncrementCounter("grr_client_sent_messages")

        # Maintain the output queue tally.
        args_size = len(queued_message.args)
        drained_size += args_size
        self._out_queue_size -= args_size

    # Restore the old order.
    self._out_queue.reverse()

    return outgoing
def testErrorDetection(self):
    """Tests that tampering with the cipher text never goes unnoticed.

    One byte of the encoded payload is altered at every 50th offset. For
    each altered payload, any message that still decodes as AUTHENTICATED
    must be identical to the message that was originally sent.
    """
    # Install the client - now we can verify its signed messages.
    self.MakeClientAFF4Record()

    # Make something to send.
    original = rdfvalue.MessageList()
    for number in range(10):
        original.job.Append(session_id=str(number))

    encoded = rdfvalue.ClientCommunication()
    self.client_communicator.EncodeMessages(original, encoded)
    cipher_text = encoded.SerializeToString()

    # Depending on this modification several things may happen:
    # 1) The padding may not match which will cause a decryption exception.
    # 2) The protobuf may fail to decode causing a decoding exception.
    # 3) The modification may affect the signature resulting in
    #    UNAUTHENTICATED messages.
    # 4) The modification may have no effect on the data at all.
    for offset in range(0, len(cipher_text), 50):
        # Futz with the cipher text (Make sure it's really changed).
        altered_char = chr((ord(cipher_text[offset]) % 250) + 1)
        mod_cipher_text = (
            cipher_text[:offset] + altered_char + cipher_text[offset + 1:])

        try:
            decoded, client_id, _ = self.server_communicator.DecryptMessage(
                mod_cipher_text)

            for position, message in enumerate(decoded):
                authenticated = message.AuthorizationState.AUTHENTICATED
                if message.auth_state != authenticated:
                    logging.debug("Message %s: Authstate: %s", position,
                                  message.auth_state)
                    continue

                # If the message is actually authenticated it must not be
                # changed!
                self.assertEqual(message.source, client_id)

                # These fields are set by the decoder and are not present
                # in the original message - so we clear them before
                # comparison.
                message.auth_state = None
                message.source = None

                self.assertProtoEqual(message, original.job[position])

        except communicator.DecodingError as e:
            logging.debug("Detected alteration at %s: %s", offset, e)
def testDrainUpdateSessionRequestStates(self):
    """Draining the flow requests and preparing messages.

    Starts a flow that queues 10 client messages, checks that every stored
    request state carries the task id of its queued message, then drains 5
    messages and checks each one of them.
    """
    # This flow sends 10 messages on Start().
    flow_obj = self.FlowSetup("SendingTestFlow")
    session_id = flow_obj.session_id

    # There should be 10 messages in the client's task queue.
    manager = queue_manager.QueueManager(token=self.token)
    tasks = manager.Query(self.client_id, 100)
    self.assertEqual(len(tasks), 10)

    # Check that the request state objects have the correct task id set
    # for the messages in the client queue.
    for task in tasks:
        request_id = task.request_id

        # Retrieve the request state for this request_id.
        request_state, _ = data_store.DB.Resolve(
            session_id.Add("state"),
            manager.FLOW_REQUEST_TEMPLATE % request_id,
            token=self.token)
        request_state = rdfvalue.RequestState(request_state)

        # Check that task_id for the client message is correctly set in
        # request_state.
        self.assertEqual(request_state.request.task_id, task.task_id)

    # Now ask the server to drain the outbound messages into the
    # message list.
    response = rdfvalue.MessageList()
    self.server.DrainTaskSchedulerQueueForClient(
        self.client_id, 5, response)

    # Check that we received only as many messages as we asked for.
    self.assertEqual(len(response.job), 5)

    # Verify every drained message. The loop used to stop after the first
    # four, leaving the last one unchecked.
    for message in response.job:
        self.assertEqual(message.session_id, session_id)
        self.assertEqual(message.name, "Test")
def Drain(self, max_size=1024):
    """Move queued outbound messages into a MessageList.

    This is used to get the messages going _TO_ the server when the client
    connects. Entries are taken in ascending order of their first element
    and removed from `self._out_queue`; `self._out_queue_size` is reduced
    by the length of each drained message's "args" value.

    Args:
        max_size: The size of the returned protobuf will be at most one
            message length over this size. The limit is checked before each
            message is taken.

    Returns:
        A MessageList protobuf.
    """
    result = rdfvalue.MessageList()
    bytes_taken = 0

    pending = self._out_queue
    pending.sort(key=lambda item: item[0])

    # Front pops are quadratic so we reverse the queue and pop from the
    # tail instead.
    pending.reverse()

    while pending and bytes_taken < max_size:
        _, next_message = pending.pop()[:2]
        result.job.Append(next_message)
        stats.STATS.IncrementCounter("grr_client_sent_messages")

        # We deliberately look at the serialized length as bytes here.
        args_length = len(next_message.Get("args"))

        # Maintain the output queue tally.
        bytes_taken += args_length
        self._out_queue_size -= args_length

    # Restore the old order.
    pending.reverse()

    return result
def UrlMock(self, req, num_messages=10, **kwargs):
    """A mock for url handler processing from the server's POV.

    Requests whose URL contains "server.pem" are answered with the frontend
    certificate. Any other request is decoded as a ClientCommunication,
    its messages are checked, and an encoded reply is returned.

    Args:
        req: The request object; `get_full_url()` and `data` are read.
        num_messages: How many response messages to send back.
        **kwargs: Accepted and ignored.

    Returns:
        A StringIO holding either the certificate or the serialized reply.

    Raises:
        urllib2.HTTPError: With code 400 on RekeyError, 406 on
            UnknownClientCert and 500 on any other exception.
    """
    if "server.pem" in req.get_full_url():
        return StringIO.StringIO(config_lib.CONFIG["Frontend.certificate"])

    _ = kwargs

    def HttpError(code):
        # Build the bare HTTPError this mock reports failures with.
        return urllib2.HTTPError(
            url=None, code=code, msg=None, hdrs=None, fp=None)

    try:
        self.client_communication = rdfvalue.ClientCommunication(req.data)

        # Decrypt incoming messages.
        decoded = self.server_communicator.DecodeMessages(
            self.client_communication)
        self.messages, source, ts = decoded

        # Make sure the messages are correct.
        self.assertEqual(source, self.client_cn)
        for position, message in enumerate(self.messages):
            # Do not check any status messages.
            if not message.request_id:
                continue
            self.assertEqual(message.response_id, position)
            self.assertEqual(message.request_id, 1)
            self.assertEqual(message.session_id, "aff4:/W:session")

        # Now prepare a response.
        reply = rdfvalue.ClientCommunication()
        reply_messages = rdfvalue.MessageList()
        for request_id in range(num_messages):
            reply_messages.job.Append(
                request_id=request_id, **self.server_response)

        # Preserve the timestamp as a nonce.
        self.server_communicator.EncodeMessages(
            reply_messages,
            reply,
            destination=source,
            timestamp=ts,
            api_version=self.client_communication.api_version)

        return StringIO.StringIO(reply.SerializeToString())

    except communicator.RekeyError:
        raise HttpError(400)
    except communicator.UnknownClientCert:
        raise HttpError(406)
    except Exception as e:
        logging.info("Exception in mock urllib2.Open: %s.", e)
        # Keep the failure around so the test can inspect it afterwards.
        self.last_urlmock_error = e

        if flags.FLAGS.debug:
            pdb.post_mortem()

        raise HttpError(500)
def RunOnce(self):
    """Makes a single request to the GRR server.

    Drains pending outbound messages (unless the previous poll failed to
    connect), posts them to the server, and queues the decoded reply for
    processing. If the POST does not return status 200, the drained
    messages are put back on the queue with a decremented ttl.

    Any exception raised along the way is logged and reported as status
    code 500 rather than propagated.

    Returns:
        A Status() object indicating how the last POST went.
    """
    # Bind the status before entering the try block, so that the catch-all
    # handler and the final return can always refer to it.
    status = Status()

    try:
        # Here we only drain messages if we were able to connect to the
        # server in the last poll request. Otherwise we just wait until the
        # connection comes back so we don't expire our messages too fast.
        if self.consecutive_connection_errors == 0:
            # Grab some messages to send.
            message_list = self.client_worker.Drain(
                max_size=config_lib.CONFIG["Client.max_post_size"])
        else:
            message_list = rdfvalue.MessageList()

        # Summarize what is about to be sent: the total count, the count
        # per priority, and whether any message asks for fast polling.
        sent_count = 0
        sent = {}
        require_fastpoll = False
        for message in message_list.job:
            sent_count += 1
            require_fastpoll |= message.require_fastpoll
            sent.setdefault(message.priority, 0)
            sent[message.priority] += 1

        status = Status(sent_count=sent_count, sent=sent,
                        require_fastpoll=require_fastpoll)

        # Make new encrypted ClientCommunication rdfvalue.
        payload = rdfvalue.ClientCommunication()

        # If our memory footprint is too large, we advertise that our input
        # queue is full. This will prevent the server from sending us any
        # messages, and hopefully allow us to work down our memory usage,
        # by processing any outstanding messages.
        if self.client_worker.MemoryExceeded():
            logging.info("Memory exceeded, will not retrieve jobs.")
            payload.queue_size = 1000000
        else:
            # Let the server know how many messages are currently queued in
            # the input queue.
            payload.queue_size = self.client_worker.InQueueSize()

        nonce = self.communicator.EncodeMessages(message_list, payload)

        response = self.MakeRequest(payload.SerializeToString(), status)

        if status.code != 200:
            # We don't print response here since it should be encrypted and
            # will cause ascii conversion errors.
            logging.info(
                "%s: Could not connect to server at %s, status %s",
                self.communicator.common_name,
                self.GetServerUrl(),
                status.code)

            # Reschedule the tasks back on the queue so they get retried
            # next time.
            messages = list(message_list.job)
            for message in messages:
                message.priority = rdfvalue.GrrMessage.Priority.HIGH_PRIORITY
                message.require_fastpoll = False
                message.ttl -= 1
                if message.ttl > 0:
                    # Schedule with high priority to make it jump the
                    # queue.
                    self.client_worker.QueueResponse(
                        message,
                        rdfvalue.GrrMessage.Priority.HIGH_PRIORITY + 1)
                else:
                    logging.info("Dropped message due to retransmissions.")

            return status

        if not response:
            return status

        try:
            tmp = self.communicator.DecryptMessage(response)
            (messages, source, server_nonce) = tmp

            # The reply must carry the nonce returned by EncodeMessages for
            # this request.
            if server_nonce != nonce:
                logging.info("Nonce not matched.")
                status.code = 500
                return status

        except proto2_message.DecodeError:
            logging.info("Protobuf decode error. Bad URL or auth.")
            status.code = 500
            return status

        if source != self.communicator.server_name:
            logging.info(
                "Received a message not from the server "
                "%s, expected %s.", source,
                self.communicator.server_name)
            status.code = 500
            return status

        status.received_count = len(messages)

        # If we're not going to fastpoll based on outbound messages, check
        # to see if any inbound messages want us to fastpoll. This means we
        # drop to fastpoll immediately on a new request rather than waiting
        # for the next beacon to report results.
        if not status.require_fastpoll:
            for message in messages:
                if message.require_fastpoll:
                    status.require_fastpoll = True
                    break

        # Process all messages. Messages can be processed by clients in
        # any order since clients do not have state.
        self.client_worker.QueueMessages(messages)

    except Exception:  # pylint: disable=broad-except
        # Catch everything, yes, this is terrible but necessary.
        logging.warning("Uncaught exception caught: %s",
                        traceback.format_exc())
        if status:
            status.code = 500
        if flags.FLAGS.debug:
            pdb.post_mortem()

    return status