Ejemplo n.º 1
0
  def testHandleClientMessageRetransmission(self):
    """Check that requests get retransmitted but only if there is no status.

    Two scenarios are exercised against a single fake client:

    1. A SendingFlow request that is never answered is drained once per
       time step for task_ttl steps. It must be handed out (ttl - 1) times
       and then no more.
    2. The same setup, but a response and status are scheduled before the
       third drain. This simulates workers too overloaded to process the
       client's reply in time. The request must be handed out exactly twice,
       and whenever a drain returns nothing the client queue must be empty.
    """
    # Make a new fake client
    client_id = self.SetupClients(1)[0]

    # Read the default TTL once and express every expectation in terms of it.
    default_ttl = rdfvalue.GrrMessage().task_ttl

    # Scenario 2 schedules the response at step 2 and needs later steps to
    # observe that the request is no longer handed out.
    if default_ttl <= 3:
      self.fail("TTL too low for this test.")

    base_time = 1000
    # Consecutive drains are spaced one step apart, and a step is longer
    # than message_expiry_time.
    step = self.message_expiry_time + 1

    # Scenario 1: the standard behavior.
    msgs_recvd = []
    with test_lib.FakeTime(base_time):
      flow.GRRFlow.StartFlow(client_id=client_id, flow_name="SendingFlow",
                             message_count=1, token=self.token)

    for i in range(default_ttl):
      with test_lib.FakeTime(base_time + i * step):
        tasks = self.server.DrainTaskSchedulerQueueForClient(
            client_id, 100000, rdfvalue.MessageList())
        msgs_recvd.append(tasks)

    # Should return a client message (ttl-1) times and nothing afterwards.
    # list() keeps the comparison meaningful where map() is lazy (Python 3).
    self.assertEqual(list(map(bool, msgs_recvd)),
                     [True] * (default_ttl - 1) + [False])

    # Scenario 2: the workers are overloaded - the client messages arrive
    # but do not get processed in time.
    msgs_recvd = []
    with test_lib.FakeTime(base_time):
      flow_id = flow.GRRFlow.StartFlow(
          client_id=client_id, flow_name="SendingFlow",
          message_count=1, token=self.token)

    for i in range(default_ttl):
      if i == 2:
        self._ScheduleResponseAndStatus(client_id, flow_id)

      with test_lib.FakeTime(base_time + i * step):
        tasks = self.server.DrainTaskSchedulerQueueForClient(
            client_id, 100000, rdfvalue.MessageList())
        msgs_recvd.append(tasks)

        if not tasks:
          # Even if the request has not been leased ttl times yet,
          # it should be dequeued by now.
          new_tasks = queue_manager.QueueManager(token=self.token).Query(
              queue=rdfvalue.ClientURN(client_id).Queue(), limit=1000)
          self.assertEqual(len(new_tasks), 0)

    # Should return a client message twice and nothing afterwards.
    self.assertEqual(list(map(bool, msgs_recvd)),
                     [True] * 2 + [False] * (default_ttl - 2))
Ejemplo n.º 2
0
    def DecompressMessageList(self, signed_message_list):
        """Turn the payload of a SignedMessageList back into a MessageList.

        Args:
          signed_message_list: A SignedMessageList rdfvalue with some data in it.

        Returns:
          a MessageList rdfvalue.

        Raises:
          DecodingError: If the compression scheme is not supported, zlib
            decompression fails, or the payload does not parse as a
            MessageList.
        """
        scheme = signed_message_list.compression
        schemes = rdfvalue.SignedMessageList.CompressionType

        if scheme == schemes.UNCOMPRESSED:
            payload = signed_message_list.message_list
        elif scheme == schemes.ZCOMPRESSION:
            try:
                payload = zlib.decompress(signed_message_list.message_list)
            except zlib.error as err:
                raise DecodingError("Failed to decompress: %s" % err)
        else:
            raise DecodingError("Compression scheme not supported")

        try:
            return rdfvalue.MessageList(payload)
        except rdfvalue.DecodeError:
            raise DecodingError("RDFValue parsing failed.")
Ejemplo n.º 3
0
    def ClientServerCommunicate(self, timestamp=None):
        """Round-trip ten messages through the client and server communicators.

        The messages are encoded with the client communicator and serialized
        (the result is kept in self.cipher_text). They are then decrypted with
        the server communicator. The test checks that the sender, the
        timestamp and every session id survive the trip.

        Args:
          timestamp: Optional timestamp handed to EncodeMessages.

        Returns:
          The messages decoded by the server communicator.
        """
        flow_numbers = range(1, 11)

        def _SessionFor(number):
            # Every test message lives under aff4:/flows on the FLOWS queue.
            return rdfvalue.SessionID(base="aff4:/flows", queue=queues.FLOWS,
                                      flow_name=number)

        outgoing = rdfvalue.MessageList()
        for number in flow_numbers:
            outgoing.job.Append(session_id=_SessionFor(number),
                                name="OMG it's a string")

        communication = rdfvalue.ClientCommunication()
        timestamp = self.client_communicator.EncodeMessages(
            outgoing, communication, timestamp=timestamp)
        self.cipher_text = communication.SerializeToString()

        decoded_messages, source, client_timestamp = (
            self.server_communicator.DecryptMessage(self.cipher_text))

        self.assertEqual(source, self.client_communicator.common_name)
        self.assertEqual(client_timestamp, timestamp)
        self.assertEqual(len(decoded_messages), 10)
        for decoded, number in zip(decoded_messages, flow_numbers):
            self.assertEqual(decoded.session_id, _SessionFor(number))

        return decoded_messages
Ejemplo n.º 4
0
    def Drain(self, max_size=1024):
        """Drain queued outbound messages into a MessageList.

        This collects the messages going _TO_ the server when the client
        connects. They are taken in the order the outbound queue's Get()
        yields them.

        Args:
           max_size: Soft byte budget. Draining stops right after the message
             that pushes the running total past this value, so the result is
             at most one message length over this size.

        Returns:
           A MessageList protobuf
        """
        drained = rdfvalue.MessageList()
        budget_used = 0

        for outbound in self._out_queue.Get():
            drained.job.Append(outbound)
            stats.STATS.IncrementCounter("grr_client_sent_messages")

            budget_used += len(outbound)
            if budget_used > max_size:
                # This message crossed the budget; stop taking more.
                break

        return drained
Ejemplo n.º 5
0
    def Drain(self, max_size=1024):
        """Pull outbound messages off the client's queue into a MessageList.

        This collects the messages going _TO_ the server when the client
        connects. Queue entries are drained in ascending order of their first
        element, and the running _out_queue_size tally is reduced as
        messages leave.

        Args:
           max_size: Byte budget, measured on len(message.args). Draining
             stops once the budget is reached, so the result is at most one
             message length over this size.

        Returns:
           A MessageList protobuf
        """
        drained = rdfvalue.MessageList()
        pending = self._out_queue

        # Sort, then flip so the entry that sorts first sits at the tail,
        # where pop() is cheap (popping from the front is O(n) each time).
        pending.sort(key=lambda entry: entry[0])
        pending.reverse()

        consumed = 0
        while pending:
            if consumed >= max_size:
                break

            outbound = pending.pop()[1]
            drained.job.Append(outbound)
            stats.STATS.IncrementCounter("grr_client_sent_messages")

            # Maintain the output queue tally.
            arg_length = len(outbound.args)
            consumed += arg_length
            self._out_queue_size -= arg_length

        # Put the remaining entries back into ascending order.
        pending.reverse()

        return drained
Ejemplo n.º 6
0
    def testErrorDetection(self):
        """Tampered cipher text must never yield altered authenticated data.

        One byte out of every 50 in an encoded ClientCommunication is changed,
        and the result is handed to the server communicator. Decoding may fail
        with a DecodingError, which is logged. Otherwise, every message
        reported as AUTHENTICATED must come from the decoding client. Once the
        decoder-set fields are cleared, it must also match the original
        message exactly.
        """
        # Install the client - now we can verify its signed messages
        self.MakeClientAFF4Record()

        # Ten messages whose session ids are "0".."9".
        originals = rdfvalue.MessageList()
        for index in range(10):
            originals.job.Append(session_id=str(index))

        encoded = rdfvalue.ClientCommunication()
        self.client_communicator.EncodeMessages(originals, encoded)
        cipher_text = encoded.SerializeToString()

        # Depending on where the byte is changed, any of these may happen:
        # 1) The padding no longer matches, causing a decryption exception.
        # 2) The protobuf fails to decode, causing a decoding exception.
        # 3) The signature is affected, so messages are not AUTHENTICATED.
        # 4) The change has no effect on the data at all.
        for offset in range(0, len(cipher_text), 50):
            # (c % 250) + 1 differs from c for every byte value c in 0..255,
            # so the byte at offset is guaranteed to change.
            replacement = chr(ord(cipher_text[offset]) % 250 + 1)
            tampered = (cipher_text[:offset] + replacement +
                        cipher_text[offset + 1:])

            try:
                decoded, client_id, _ = self.server_communicator.DecryptMessage(
                    tampered)

                for index, message in enumerate(decoded):
                    authenticated = (message.auth_state ==
                                     message.AuthorizationState.AUTHENTICATED)
                    if not authenticated:
                        logging.debug("Message %s: Authstate: %s", index,
                                      message.auth_state)
                        continue

                    # An authenticated message must not have been changed!
                    self.assertEqual(message.source, client_id)

                    # The decoder sets these fields and the originals lack
                    # them, so clear them before comparing.
                    message.auth_state = None
                    message.source = None
                    self.assertProtoEqual(message, originals.job[index])

            except communicator.DecodingError as e:
                logging.debug("Detected alteration at %s: %s", offset, e)
Ejemplo n.º 7
0
  def testDrainUpdateSessionRequestStates(self):
    """Draining the flow requests and preparing messages.

    The test starts SendingTestFlow, which queues 10 client messages. It
    checks that the stored RequestState for each queued task records that
    task's task_id. It then drains 5 messages and checks that every drained
    message belongs to the flow and is named "Test".
    """
    # This flow sends 10 messages on Start()
    flow_obj = self.FlowSetup("SendingTestFlow")
    session_id = flow_obj.session_id

    # There should be 10 messages in the client's task queue
    manager = queue_manager.QueueManager(token=self.token)
    tasks = manager.Query(self.client_id, 100)
    self.assertEqual(len(tasks), 10)

    # Check that the request state objects have the correct task_id set
    # for the matching task in the client queue:
    for task in tasks:
      request_id = task.request_id

      # Retrieve the request state for this request_id
      request_state, _ = data_store.DB.Resolve(
          session_id.Add("state"),
          manager.FLOW_REQUEST_TEMPLATE % request_id,
          token=self.token)

      request_state = rdfvalue.RequestState(request_state)

      # Check that task_id for the client message is correctly set in
      # request_state.
      self.assertEqual(request_state.request.task_id, task.task_id)

    # Now ask the server to drain the outbound messages into the
    # message list.
    response = rdfvalue.MessageList()

    self.server.DrainTaskSchedulerQueueForClient(
        self.client_id, 5, response)

    # Check that we received only as many messages as we asked for
    self.assertEqual(len(response.job), 5)

    # Every drained message, including the last one, must come from this flow.
    for job in response.job:
      self.assertEqual(job.session_id, session_id)
      self.assertEqual(job.name, "Test")
Ejemplo n.º 8
0
  def Drain(self, max_size=1024):
    """Pull outbound messages off the client's queue into a MessageList.

    This collects the messages going _TO_ the server when the client
    connects. Queue entries are drained in ascending order of their first
    element, and the running _out_queue_size tally is reduced as messages
    leave.

    Args:
       max_size: Byte budget, measured on len(message.Get("args")). Draining
         stops once the budget is reached, so the result is at most one
         message length over this size.

    Returns:
       A MessageList protobuf
    """
    drained = rdfvalue.MessageList()
    pending = self._out_queue

    # Sort, then flip so the entry that sorts first sits at the tail, where
    # pop() is cheap (popping from the front is O(n) each time).
    pending.sort(key=lambda entry: entry[0])
    pending.reverse()

    consumed = 0
    while pending:
      if consumed >= max_size:
        break

      outbound = pending.pop()[1]
      drained.job.Append(outbound)
      stats.STATS.IncrementCounter("grr_client_sent_messages")

      # Deliberately measured via Get("args"), i.e. the serialized length in
      # bytes, and used for both the budget and the queue tally.
      arg_bytes = len(outbound.Get("args"))
      consumed += arg_bytes
      self._out_queue_size -= arg_bytes

    # Put the remaining entries back into ascending order.
    pending.reverse()

    return drained
Ejemplo n.º 9
0
    def UrlMock(self, req, num_messages=10, **kwargs):
        """Mock for urllib2 URL opening that plays the server's part.

        A URL containing "server.pem" is answered with the configured frontend
        certificate. Anything else is treated as a posted ClientCommunication.
        Its messages are decoded and sanity-checked. The mock then returns an
        encrypted reply holding num_messages messages, built from
        self.server_response with request_id 0..num_messages-1. The client's
        timestamp is reused as the nonce.

        Args:
          req: The request; get_full_url() and data are read from it.
          num_messages: Number of messages to put in the reply.
          **kwargs: Accepted and ignored.

        Returns:
          A StringIO.StringIO holding the serialized reply.

        Raises:
          urllib2.HTTPError: code 400 on communicator.RekeyError, 406 on
            communicator.UnknownClientCert, and 500 on any other exception.
            Other exceptions are also logged and stored in
            self.last_urlmock_error.
        """
        def _HttpError(code):
            # Every mocked failure is an otherwise-empty HTTPError.
            return urllib2.HTTPError(url=None, code=code, msg=None, hdrs=None,
                                     fp=None)

        if "server.pem" in req.get_full_url():
            return StringIO.StringIO(config_lib.CONFIG["Frontend.certificate"])

        del kwargs  # Unused.

        try:
            self.client_communication = rdfvalue.ClientCommunication(req.data)

            # Decrypt incoming messages
            self.messages, source, ts = self.server_communicator.DecodeMessages(
                self.client_communication)

            # Make sure the messages are correct. Messages without a
            # request_id (status messages) are not checked.
            self.assertEqual(source, self.client_cn)
            for index, message in enumerate(self.messages):
                if not message.request_id:
                    continue
                self.assertEqual(message.response_id, index)
                self.assertEqual(message.request_id, 1)
                self.assertEqual(message.session_id, "aff4:/W:session")

            # Now prepare a response
            reply = rdfvalue.ClientCommunication()
            outgoing = rdfvalue.MessageList()
            for request_id in range(num_messages):
                outgoing.job.Append(request_id=request_id,
                                    **self.server_response)

            # Preserve the timestamp as a nonce
            self.server_communicator.EncodeMessages(
                outgoing,
                reply,
                destination=source,
                timestamp=ts,
                api_version=self.client_communication.api_version)

            return StringIO.StringIO(reply.SerializeToString())
        except communicator.RekeyError:
            raise _HttpError(400)
        except communicator.UnknownClientCert:
            raise _HttpError(406)
        except Exception as e:
            logging.info("Exception in mock urllib2.Open: %s.", e)
            self.last_urlmock_error = e

            if flags.FLAGS.debug:
                pdb.post_mortem()

            raise _HttpError(500)
Ejemplo n.º 10
0
    def RunOnce(self):
        """Makes a single request to the GRR server.

        Outbound messages are drained from the client worker only if the
        previous poll had no connection errors. They are encrypted into a
        ClientCommunication and POSTed.
        - On a non-200 status, the drained messages are requeued at high
          priority with their ttl decremented. They are dropped once the ttl
          reaches zero.
        - On success, the reply is decrypted and its nonce and sender are
          validated. The received messages are then queued on the client
          worker.

        Returns:
          A Status() object indicating how the last POST went. Any unexpected
          exception is logged and reported as status code 500. If even the
          initial Status() cannot be constructed, None is returned.
        """
        # Bound before the try so the broad handler below never hits an
        # UnboundLocalError when the first Status() call itself fails.
        status = None
        try:
            status = Status()

            # Here we only drain messages if we were able to connect to the server in
            # the last poll request. Otherwise we just wait until the connection comes
            # back so we don't expire our messages too fast.
            if self.consecutive_connection_errors == 0:
                # Grab some messages to send
                message_list = self.client_worker.Drain(
                    max_size=config_lib.CONFIG["Client.max_post_size"])
            else:
                message_list = rdfvalue.MessageList()

            sent_count = 0
            sent = {}
            require_fastpoll = False

            # Tally what is being sent: total count, count per priority, and
            # whether any outbound message asks for fastpoll.
            for message in message_list.job:
                sent_count += 1

                require_fastpoll |= message.require_fastpoll

                sent[message.priority] = sent.get(message.priority, 0) + 1

            status = Status(sent_count=sent_count,
                            sent=sent,
                            require_fastpoll=require_fastpoll)

            # Make new encrypted ClientCommunication rdfvalue.
            payload = rdfvalue.ClientCommunication()

            # If our memory footprint is too large, we advertise that our input queue
            # is full. This will prevent the server from sending us any messages, and
            # hopefully allow us to work down our memory usage, by processing any
            # outstanding messages.
            if self.client_worker.MemoryExceeded():
                logging.info("Memory exceeded, will not retrieve jobs.")
                payload.queue_size = 1000000
            else:
                # Let the server know how many messages are currently queued in
                # the input queue.
                payload.queue_size = self.client_worker.InQueueSize()

            nonce = self.communicator.EncodeMessages(message_list, payload)
            response = self.MakeRequest(payload.SerializeToString(), status)

            if status.code != 200:
                # We don't print response here since it should be encrypted and will
                # cause ascii conversion errors.
                logging.info(
                    "%s: Could not connect to server at %s, status %s",
                    self.communicator.common_name, self.GetServerUrl(),
                    status.code)

                # Reschedule the tasks back on the queue so they get retried next time.
                messages = list(message_list.job)
                for message in messages:
                    message.priority = rdfvalue.GrrMessage.Priority.HIGH_PRIORITY
                    message.require_fastpoll = False
                    message.ttl -= 1
                    if message.ttl > 0:
                        # Schedule with high priority to make it jump the queue.
                        self.client_worker.QueueResponse(
                            message,
                            rdfvalue.GrrMessage.Priority.HIGH_PRIORITY + 1)
                    else:
                        logging.info("Dropped message due to retransmissions.")
                return status

            if not response:
                return status

            try:
                messages, source, server_nonce = self.communicator.DecryptMessage(
                    response)

                # The reply must echo the nonce we sent, otherwise it is not
                # an answer to this request.
                if server_nonce != nonce:
                    logging.info("Nonce not matched.")
                    status.code = 500
                    return status

            except proto2_message.DecodeError:
                logging.info("Protobuf decode error. Bad URL or auth.")
                status.code = 500
                return status

            if source != self.communicator.server_name:
                logging.info(
                    "Received a message not from the server "
                    "%s, expected %s.", source, self.communicator.server_name)
                status.code = 500
                return status

            status.received_count = len(messages)

            # If we're not going to fastpoll based on outbound messages, check to see
            # if any inbound messages want us to fastpoll. This means we drop to
            # fastpoll immediately on a new request rather than waiting for the next
            # beacon to report results.
            if not status.require_fastpoll:
                for message in messages:
                    if message.require_fastpoll:
                        status.require_fastpoll = True
                        break

            # Process all messages. Messages can be processed by clients in
            # any order since clients do not have state.
            self.client_worker.QueueMessages(messages)

        except Exception:  # pylint: disable=broad-except
            # Catch everything, yes, this is terrible but necessary
            logging.warning("Uncaught exception caught: %s",
                            traceback.format_exc())
            if status is not None:
                status.code = 500
            if flags.FLAGS.debug:
                pdb.post_mortem()

        return status