def ClientServerCommunicate(self, timestamp=None):
  """Tests the end to end encrypted communicators."""
  message_list = rdf_flows.MessageList()
  for i in range(1, 11):
    message_list.job.Append(
        session_id=rdfvalue.SessionID(
            base="aff4:/flows", queue=queues.FLOWS, flow_name=i),
        name="OMG it's a string")

  result = rdf_flows.ClientCommunication()
  timestamp = self.client_communicator.EncodeMessages(
      message_list, result, timestamp=timestamp)
  self.cipher_text = result.SerializeToString()

  (decoded_messages, source, client_timestamp) = (
      self.server_communicator.DecryptMessage(self.cipher_text))

  self.assertEqual(source, self.client_communicator.common_name)
  self.assertEqual(client_timestamp, timestamp)
  self.assertEqual(len(decoded_messages), 10)
  for i in range(1, 11):
    self.assertEqual(
        decoded_messages[i - 1].session_id,
        rdfvalue.SessionID(
            base="aff4:/flows", queue=queues.FLOWS, flow_name=i))

  return decoded_messages

def Drain(self, max_size=1024):
  """Return a GrrQueue message list from the queue, draining it.

  This is used to get the messages going _TO_ the server when the client
  connects.

  Args:
    max_size: The size (in bytes) of the returned protobuf will be at most
      one message length over this size.

  Returns:
    A MessageList protobuf
  """
  queue = rdf_flows.MessageList()
  length = 0

  for message in self._out_queue.Get():
    queue.job.Append(rdf_flows.GrrMessage(message))
    stats.STATS.IncrementCounter("grr_client_sent_messages")
    length += len(message)

    if length > max_size:
      break

  return queue

def DecompressMessageList(self, signed_message_list):
  """Decompress the message data from signed_message_list.

  Args:
    signed_message_list: A SignedMessageList rdfvalue with some data in it.

  Returns:
    a MessageList rdfvalue.

  Raises:
    DecodingError: If decompression fails.
  """
  compression = signed_message_list.compression
  if compression == rdf_flows.SignedMessageList.CompressionType.UNCOMPRESSED:
    data = signed_message_list.message_list

  elif (compression ==
        rdf_flows.SignedMessageList.CompressionType.ZCOMPRESSION):
    try:
      data = zlib.decompress(signed_message_list.message_list)
    except zlib.error as e:
      raise DecodingError("Failed to decompress: %s" % e)

  else:
    raise DecodingError("Compression scheme not supported")

  try:
    result = rdf_flows.MessageList(data)
  except rdfvalue.DecodeError:
    raise DecodingError("RDFValue parsing failed.")

  return result

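# Side sketch (not GRR code): the compression round trip that
# DecompressMessageList undoes, using only the standard zlib module. The raw
# bytes below stand in for a serialized MessageList; the SignedMessageList
# wrapper and its compression enum are assumed to work as shown above.
import zlib

serialized_message_list = b"example serialized MessageList bytes"
compressed = zlib.compress(serialized_message_list)  # the ZCOMPRESSION path
restored = zlib.decompress(compressed)               # what the server then does
assert restored == serialized_message_list
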
def _SendMessages(self, grr_msgs, priority=fs_common_pb2.Message.MEDIUM):
  """Sends a block of messages through Fleetspeak."""
  message_list = rdf_flows.PackedMessageList()
  communicator.Communicator.EncodeMessageList(
      rdf_flows.MessageList(job=grr_msgs), message_list)
  fs_msg = fs_common_pb2.Message(
      message_type="MessageList",
      destination=fs_common_pb2.Address(service_name="GRR"),
      priority=priority)
  fs_msg.data.Pack(message_list.AsPrimitiveProto())

  try:
    sent_bytes = self._fs.Send(fs_msg)
  except (IOError, struct.error) as e:
    logging.fatal(
        "Broken local Fleetspeak connection (write end): %r", e, exc_info=True)
    # logging.fatal does not terminate the program by itself (for example
    # because of lingering Python threads or the logging configuration), so we
    # exit explicitly.
    os._exit(1)  # pylint: disable=protected-access

  stats.STATS.IncrementCounter("grr_client_sent_bytes", sent_bytes)

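# Side sketch (not GRR code): the google.protobuf Any packing performed by
# fs_msg.data.Pack(...) above, with a well-known wrapper type standing in for
# the PackedMessageList primitive proto.
from google.protobuf import any_pb2
from google.protobuf import wrappers_pb2

payload = wrappers_pb2.StringValue(value="packed message list stand-in")
container = any_pb2.Any()
container.Pack(payload)            # what fs_msg.data.Pack(...) does

unpacked = wrappers_pb2.StringValue()
assert container.Unpack(unpacked)  # the receiver unpacks by matching type
assert unpacked.value == payload.value
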
def HandleMessageBundles(self, request_comms, response_comms):
  """Processes a queue of messages as passed from the client.

  We basically dispatch all the GrrMessages in the queue to the task scheduler
  for backend processing. We then retrieve from the TS the messages destined
  for this client.

  Args:
    request_comms: A ClientCommunication rdfvalue with messages sent by the
      client. source should be set to the client CN.
    response_comms: A ClientCommunication rdfvalue of jobs destined to this
      client.

  Returns:
    tuple of (source, message_count) where message_count is the number of
    messages received from the client with common name source.
  """
  messages, source, timestamp = self._communicator.DecodeMessages(
      request_comms)

  now = time.time()
  if messages:
    # Receive messages in line.
    self.ReceiveMessages(source, messages)

  # We send the client a maximum of self.max_queue_size messages
  required_count = max(0, self.max_queue_size - request_comms.queue_size)
  tasks = []

  message_list = rdf_flows.MessageList()
  # Only give the client messages if we are able to receive them in a
  # reasonable time.
  if time.time() - now < 10:
    tasks = self.DrainTaskSchedulerQueueForClient(source, required_count)
    message_list.job = tasks

  # Encode the message_list in the response_comms using the same API version
  # the client used.
  try:
    self._communicator.EncodeMessages(
        message_list,
        response_comms,
        destination=str(source),
        timestamp=timestamp,
        api_version=request_comms.api_version)
  except communicator.UnknownClientCert:
    # We can not encode messages to the client yet because we do not have the
    # client certificate - return them to the queue so we can try again later.
    with data_store.DB.GetMutationPool(token=self.token) as pool:
      queue_manager.QueueManager(token=self.token).Schedule(tasks, pool)
    raise

  return source, len(messages)

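# Side sketch of the flow-control arithmetic in HandleMessageBundles, with
# illustrative numbers only (the real limit comes from self.max_queue_size):
# a client that already reports 450 queued messages out of an allowed 500 is
# only sent up to 50 more.
max_queue_size = 500     # illustrative stand-in for self.max_queue_size
client_queue_size = 450  # stand-in for request_comms.queue_size
required_count = max(0, max_queue_size - client_queue_size)
assert required_count == 50
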
def UrlMock(self, req, num_messages=10, **kwargs):
  """A mock for url handler processing from the server's POV."""
  if "server.pem" in req.get_full_url():
    return StringIO.StringIO(config_lib.CONFIG["Frontend.certificate"])

  _ = kwargs
  try:
    self.client_communication = rdf_flows.ClientCommunication(req.data)

    # Decrypt incoming messages
    self.messages, source, ts = self.server_communicator.DecodeMessages(
        self.client_communication)

    # Make sure the messages are correct
    self.assertEqual(source, self.client_cn)
    for i, message in enumerate(self.messages):
      # Do not check any status messages.
      if message.request_id:
        self.assertEqual(message.response_id, i)
        self.assertEqual(message.request_id, 1)
        self.assertEqual(message.session_id, "aff4:/W:session")

    # Now prepare a response
    response_comms = rdf_flows.ClientCommunication()
    message_list = rdf_flows.MessageList()
    for i in range(0, num_messages):
      message_list.job.Append(request_id=i, **self.server_response)

    # Preserve the timestamp as a nonce
    self.server_communicator.EncodeMessages(
        message_list,
        response_comms,
        destination=source,
        timestamp=ts,
        api_version=self.client_communication.api_version)

    return StringIO.StringIO(response_comms.SerializeToString())
  except communicator.UnknownClientCert:
    raise urllib2.HTTPError(url=None, code=406, msg=None, hdrs=None, fp=None)
  except Exception as e:
    logging.info("Exception in mock urllib2.Open: %s.", e)
    self.last_urlmock_error = e

    if flags.FLAGS.debug:
      pdb.post_mortem()

    raise urllib2.HTTPError(url=None, code=500, msg=None, hdrs=None, fp=None)

def testErrorDetection(self):
  """Tests the end to end encrypted communicators."""
  # Install the client - now we can verify its signed messages
  self.MakeClientAFF4Record()

  # Make something to send
  message_list = rdf_flows.MessageList()
  for i in range(0, 10):
    message_list.job.Append(session_id=str(i))

  result = rdf_flows.ClientCommunication()
  self.client_communicator.EncodeMessages(message_list, result)
  cipher_text = result.SerializeToString()

  # Depending on this modification several things may happen:
  # 1) The padding may not match which will cause a decryption exception.
  # 2) The protobuf may fail to decode causing a decoding exception.
  # 3) The modification may affect the signature resulting in UNAUTHENTICATED
  #    messages.
  # 4) The modification may have no effect on the data at all.
  for x in range(0, len(cipher_text), 50):
    # Futz with the cipher text (Make sure it's really changed)
    mod_cipher_text = (
        cipher_text[:x] + chr((ord(cipher_text[x]) % 250) + 1) +
        cipher_text[x + 1:])

    try:
      decoded, client_id, _ = self.server_communicator.DecryptMessage(
          mod_cipher_text)

      for i, message in enumerate(decoded):
        # If the message is actually authenticated it must not be changed!
        if message.auth_state == message.AuthorizationState.AUTHENTICATED:
          self.assertEqual(message.source, client_id)

          # These fields are set by the decoder and are not present in the
          # original message - so we clear them before comparison.
          message.auth_state = None
          message.source = None
          self.assertRDFValuesEqual(message, message_list.job[i])
        else:
          logging.debug("Message %s: Authstate: %s", i, message.auth_state)

    except communicator.DecodingError as e:
      logging.debug("Detected alteration at %s: %s", x, e)

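# Side sketch: the mutation chr((ord(c) % 250) + 1) used in testErrorDetection
# always yields a different byte value, so every probed offset in cipher_text
# really is altered.
assert all((b % 250) + 1 != b for b in range(256))
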
def testPingIsRecorded(self):
  service_name = "GRR"
  fake_service_client = _FakeGRPCServiceClient(service_name)

  fleetspeak_connector.Reset()
  fleetspeak_connector.Init(service_client=fake_service_client)

  fsd = fs_frontend_tool.GRRFSServer()

  grr_client_nr = 0xab
  grr_client = self.SetupTestClientObject(grr_client_nr)
  self.SetupClient(grr_client_nr)

  messages = [
      rdf_flows.GrrMessage(
          request_id=1,
          response_id=1,
          session_id="F:123456",
          payload=rdfvalue.RDFInteger(1))
  ]

  fs_client_id = "\x10\x00\x00\x00\x00\x00\x00\xab"
  # fs_client_id should be equivalent to grr_client.client_id.
  self.assertEqual(
      fs_client_id,
      fleetspeak_utils.GRRIDToFleetspeakID(grr_client.client_id))

  message_list = rdf_flows.PackedMessageList()
  communicator.Communicator.EncodeMessageList(
      rdf_flows.MessageList(job=messages), message_list)

  fs_message = fs_common_pb2.Message(
      message_type="MessageList",
      source=fs_common_pb2.Address(
          client_id=fs_client_id, service_name=service_name))
  fs_message.data.Pack(message_list.AsPrimitiveProto())

  fake_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42)
  with test_lib.FakeTime(fake_time):
    fsd.Process(fs_message, None)

  md = data_store.REL_DB.ReadClientMetadata(grr_client.client_id)
  self.assertEqual(md.ping, fake_time)

  with aff4.FACTORY.Open(grr_client.client_id) as client:
    self.assertEqual(client.Get(client.Schema.PING), fake_time)

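# Side sketch (assumed mapping, not the real fleetspeak_utils implementation):
# the 8-byte Fleetspeak id used above is simply the 16 hex digits of the GRR
# client id decoded to raw bytes. "C.10000000000000ab" is assumed to be the id
# the test helpers produce for client number 0xab.
import binascii

grr_client_id = "C.10000000000000ab"
raw_fleetspeak_id = binascii.unhexlify(grr_client_id[2:])
assert raw_fleetspeak_id == b"\x10\x00\x00\x00\x00\x00\x00\xab"
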
def _SendMessages(self, grr_msgs, background=False):
  """Sends a block of messages through Fleetspeak."""
  message_list = rdf_flows.PackedMessageList()
  communicator.Communicator.EncodeMessageList(
      rdf_flows.MessageList(job=grr_msgs), message_list)
  fs_msg = fs_common_pb2.Message(
      message_type="MessageList",
      destination=fs_common_pb2.Address(service_name="GRR"),
      background=background)
  fs_msg.data.Pack(message_list.AsPrimitiveProto())

  try:
    sent_bytes = self._fs.Send(fs_msg)
  except (IOError, struct.error):
    logging.critical("Broken local Fleetspeak connection (write end).")
    # Re-raise with a bare raise so the original traceback is preserved.
    raise

  stats.STATS.IncrementCounter("grr_client_sent_bytes", sent_bytes)

def testDrainUpdateSessionRequestStates(self):
  """Draining the flow requests and preparing messages."""
  # This flow sends 10 messages on Start()
  flow_obj = self.FlowSetup("SendingTestFlow")
  session_id = flow_obj.session_id

  # There should be 10 messages in the client's task queue
  manager = queue_manager.QueueManager(token=self.token)
  tasks = manager.Query(self.client_id, 100)
  self.assertEqual(len(tasks), 10)

  requests_by_id = {}
  for request, _ in data_store.DB.ReadRequestsAndResponses(
      session_id, token=self.token):
    requests_by_id[request.id] = request

  # Check that the response state objects have the correct ts_id set
  # in the client_queue:
  for task in tasks:
    request_id = task.request_id

    # Retrieve the request state for this request_id
    request = requests_by_id[request_id]

    # Check that task_id for the client message is correctly set in
    # request_state.
    self.assertEqual(request.request.task_id, task.task_id)

  # Now ask the server to drain the outbound messages into the
  # message list.
  response = rdf_flows.MessageList()
  response.job = self.server.DrainTaskSchedulerQueueForClient(
      self.client_id, 5)

  # Check that we received only as many messages as we asked for
  self.assertEqual(len(response.job), 5)

  for i in range(4):
    self.assertEqual(response.job[i].session_id, session_id)
    self.assertEqual(response.job[i].name, "Test")

def Drain(self, max_size=1024):
  """Return a GrrQueue message list from the queue, draining it.

  This is used to get the messages going _TO_ the server when the client
  connects.

  Args:
    max_size: The size (in bytes) of the returned protobuf will be at most
      one message length over this size.

  Returns:
    A MessageList protobuf
  """
  queue = rdf_flows.MessageList()
  length = 0

  self._out_queue.sort(key=lambda msg: msg[0])

  # Front pops are quadratic so we reverse the queue.
  self._out_queue.reverse()

  while self._out_queue and length < max_size:
    message = self._out_queue.pop()[1]
    queue.job.Append(message)
    stats.STATS.IncrementCounter("grr_client_sent_messages")

    # We deliberately look at the serialized length as bytes here.
    message_length = len(message.Get("args"))

    # Maintain the output queue tally
    length += message_length
    self._out_queue_size -= message_length

  # Restore the old order.
  self._out_queue.reverse()

  return queue

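# Side sketch of the reverse/pop pattern used in Drain above: popping from the
# front of a Python list is O(n) per pop, so the queue is reversed once,
# drained from the back, and reversed again to restore the original order. The
# tuples below are stand-ins for the (priority, message) entries in _out_queue.
pending = [(1, "a"), (2, "b"), (3, "c")]
pending.sort(key=lambda item: item[0])
pending.reverse()

drained = []
while pending and len(drained) < 2:  # the count check stands in for max_size
  drained.append(pending.pop()[1])

pending.reverse()
assert drained == ["a", "b"]
assert pending == [(3, "c")]
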
def testReceiveMessageListFleetspeak(self):
  service_name = "GRR"
  fake_service_client = _FakeGRPCServiceClient(service_name)

  fleetspeak_connector.Reset()
  fleetspeak_connector.Init(service_client=fake_service_client)

  fsd = fs_frontend_tool.GRRFSServer()

  grr_client_nr = 0xab
  grr_client_id_urn = self.SetupClient(grr_client_nr)

  flow_obj = self.FlowSetup(flow_test_lib.FlowOrderTest.__name__,
                            grr_client_id_urn)

  num_msgs = 9
  session_id = flow_obj.session_id
  messages = [
      rdf_flows.GrrMessage(
          request_id=1,
          response_id=i,
          session_id=session_id,
          payload=rdfvalue.RDFInteger(i)) for i in xrange(1, num_msgs + 1)
  ]

  fs_client_id = "\x10\x00\x00\x00\x00\x00\x00\xab"
  # fs_client_id should be equivalent to grr_client_id_urn
  self.assertEqual(
      fs_client_id,
      fleetspeak_utils.GRRIDToFleetspeakID(grr_client_id_urn.Basename()))

  message_list = rdf_flows.PackedMessageList()
  communicator.Communicator.EncodeMessageList(
      rdf_flows.MessageList(job=messages), message_list)

  fs_message = fs_common_pb2.Message(
      message_type="MessageList",
      source=fs_common_pb2.Address(
          client_id=fs_client_id, service_name=service_name))
  fs_message.data.Pack(message_list.AsPrimitiveProto())
  fsd.Process(fs_message, None)

  # Make sure the task is still on the client queue
  manager = queue_manager.QueueManager(token=self.token)
  tasks_on_client_queue = manager.Query(grr_client_id_urn, 100)
  self.assertEqual(len(tasks_on_client_queue), 1)

  want_messages = [message.Copy() for message in messages]
  for want_message in want_messages:
    # This is filled in by the frontend as soon as it gets the message.
    want_message.auth_state = (
        rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED)
    want_message.source = grr_client_id_urn

  stored_messages = data_store.DB.ReadResponsesForRequestId(session_id, 1)
  self.assertEqual(len(stored_messages), len(want_messages))
  stored_messages.sort(key=lambda m: m.response_id)

  # Check that messages were stored correctly
  for stored_message, want_message in itertools.izip(stored_messages,
                                                     want_messages):
    stored_message.timestamp = None
    self.assertRDFValuesEqual(stored_message, want_message)

def testHandleClientMessageRetransmission(self):
  """Check that requests get retransmitted but only if there is no status."""
  # Make a new fake client
  client_id = self.SetupClients(1)[0]

  # Test the standard behavior.
  base_time = 1000
  msgs_recvd = []

  default_ttl = rdf_flows.GrrMessage().task_ttl
  with test_lib.FakeTime(base_time):
    flow.GRRFlow.StartFlow(
        client_id=client_id,
        flow_name="SendingFlow",
        message_count=1,
        token=self.token)

  for i in range(default_ttl):
    with test_lib.FakeTime(base_time + i * (self.message_expiry_time + 1)):
      tasks = self.server.DrainTaskSchedulerQueueForClient(
          client_id, 100000, rdf_flows.MessageList())
      msgs_recvd.append(tasks)

  # Should return a client message (ttl-1) times and nothing afterwards.
  self.assertEqual(
      map(bool, msgs_recvd),
      [True] * (rdf_flows.GrrMessage().task_ttl - 1) + [False])

  # Now we simulate that the workers are overloaded - the client messages
  # arrive but do not get processed in time.
  if default_ttl <= 3:
    self.fail("TTL too low for this test.")

  msgs_recvd = []

  with test_lib.FakeTime(base_time):
    flow_id = flow.GRRFlow.StartFlow(
        client_id=client_id,
        flow_name="SendingFlow",
        message_count=1,
        token=self.token)

  for i in range(default_ttl):
    if i == 2:
      self._ScheduleResponseAndStatus(client_id, flow_id)

    with test_lib.FakeTime(base_time + i * (self.message_expiry_time + 1)):
      tasks = self.server.DrainTaskSchedulerQueueForClient(
          client_id, 100000, rdf_flows.MessageList())
      msgs_recvd.append(tasks)

      if not tasks:
        # Even if the request has not been leased ttl times yet,
        # it should be dequeued by now.
        new_tasks = queue_manager.QueueManager(token=self.token).Query(
            queue=rdf_client.ClientURN(client_id).Queue(), limit=1000)
        self.assertEqual(len(new_tasks), 0)

  # Should return a client message twice and nothing afterwards.
  self.assertEqual(
      map(bool, msgs_recvd),
      [True] * 2 + [False] * (rdf_flows.GrrMessage().task_ttl - 2))

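# Side sketch of the lease pattern asserted above, with an assumed task_ttl of
# 5 (the real default comes from rdf_flows.GrrMessage): an unanswered request
# is handed out on the first four drains and dropped on the fifth.
task_ttl = 5  # illustrative value
expected_pattern = [True] * (task_ttl - 1) + [False]
assert expected_pattern == [True, True, True, True, False]
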
def RunOnce(self):
  """Makes a single request to the GRR server.

  Returns:
    A Status() object indicating how the last POST went.
  """
  try:
    status = Status()

    # Here we only drain messages if we were able to connect to the server in
    # the last poll request. Otherwise we just wait until the connection comes
    # back so we don't expire our messages too fast.
    if self.consecutive_connection_errors == 0:
      # Grab some messages to send
      message_list = self.client_worker.Drain(
          max_size=config_lib.CONFIG["Client.max_post_size"])
    else:
      message_list = rdf_flows.MessageList()

    sent_count = 0
    sent = {}
    require_fastpoll = False

    for message in message_list.job:
      sent_count += 1
      require_fastpoll |= message.require_fastpoll
      sent.setdefault(message.priority, 0)
      sent[message.priority] += 1

    status = Status(
        sent_count=sent_count, sent=sent, require_fastpoll=require_fastpoll)

    # Make new encrypted ClientCommunication rdfvalue.
    payload = rdf_flows.ClientCommunication()

    # If our memory footprint is too large, we advertise that our input queue
    # is full. This will prevent the server from sending us any messages, and
    # hopefully allow us to work down our memory usage, by processing any
    # outstanding messages.
    if self.client_worker.MemoryExceeded():
      logging.info("Memory exceeded, will not retrieve jobs.")
      payload.queue_size = 1000000
    else:
      # Let the server know how many messages are currently queued in
      # the input queue.
      payload.queue_size = self.client_worker.InQueueSize()

    nonce = self.communicator.EncodeMessages(message_list, payload)
    response = self.MakeRequest(payload.SerializeToString(), status)

    if status.code != 200:
      # We don't print response here since it should be encrypted and will
      # cause ascii conversion errors.
      logging.info("%s: Could not connect to server at %s, status %s",
                   self.communicator.common_name, self.GetServerUrl(),
                   status.code)

      # Reschedule the tasks back on the queue so they get retried next time.
      messages = list(message_list.job)
      for message in messages:
        message.priority = rdf_flows.GrrMessage.Priority.HIGH_PRIORITY
        message.require_fastpoll = False
        message.ttl -= 1
        if message.ttl > 0:
          # Schedule with high priority to make it jump the queue.
          self.client_worker.QueueResponse(
              message, rdf_flows.GrrMessage.Priority.HIGH_PRIORITY + 1)
        else:
          logging.info("Dropped message due to retransmissions.")

      return status

    if not response:
      return status

    try:
      tmp = self.communicator.DecryptMessage(response)
      (messages, source, server_nonce) = tmp

      if server_nonce != nonce:
        logging.info("Nonce not matched.")
        status.code = 500
        return status

    except proto2_message.DecodeError:
      logging.info("Protobuf decode error. Bad URL or auth.")
      status.code = 500
      return status

    if source != self.communicator.server_name:
      logging.info("Received a message not from the server %s, expected %s.",
                   source, self.communicator.server_name)
      status.code = 500
      return status

    status.received_count = len(messages)

    # If we're not going to fastpoll based on outbound messages, check to see
    # if any inbound messages want us to fastpoll. This means we drop to
    # fastpoll immediately on a new request rather than waiting for the next
    # beacon to report results.
    if not status.require_fastpoll:
      for message in messages:
        if message.require_fastpoll:
          status.require_fastpoll = True
          break

    # Process all messages. Messages can be processed by clients in any order
    # since clients do not have state.
    self.client_worker.QueueMessages(messages)

  except Exception:  # pylint: disable=broad-except
    # Catch everything, yes, this is terrible but necessary
    logging.warn("Uncaught exception caught: %s", traceback.format_exc())
    if status:
      status.code = 500
    if flags.FLAGS.debug:
      pdb.post_mortem()

  return status

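# Side sketch of the retry bookkeeping in RunOnce: when the POST fails, every
# outbound message has its ttl decremented and is only requeued while the ttl
# stays positive; otherwise it is dropped. FakeMessage is a hypothetical
# stand-in for rdf_flows.GrrMessage.
class FakeMessage(object):

  def __init__(self, ttl):
    self.ttl = ttl


requeued, dropped = [], []
for message in [FakeMessage(ttl=2), FakeMessage(ttl=1)]:
  message.ttl -= 1
  if message.ttl > 0:
    requeued.append(message)  # RunOnce would call QueueResponse(...) here
  else:
    dropped.append(message)   # "Dropped message due to retransmissions."

assert len(requeued) == 1 and len(dropped) == 1
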