Ejemplo n.º 1
0
    def testMessageHandlerRequests(self):
        """Round-trips message handler requests through write, read and delete.

        Five requests are written and read back; each read-back request must
        carry a timestamp. After clearing the timestamps, the requests must
        equal the originals. Three requests are then deleted, and only the
        remaining two are expected back.
        """
        requests = [
            rdf_objects.MessageHandlerRequest(client_id="C.1000000000000000",
                                              handler_name="Testhandler",
                                              request_id=i * 100,
                                              request=rdfvalue.RDFInteger(i))
            for i in range(5)
        ]

        self.db.WriteMessageHandlerRequests(requests)

        read = self.db.ReadMessageHandlerRequests()
        for r in read:
            # Read-back requests carry a timestamp the originals lack; clear it
            # so the objects compare equal to what was written.
            self.assertTrue(r.timestamp)
            r.timestamp = None

        self.assertEqual(sorted(read, key=lambda req: req.request_id),
                         requests)

        # Deletion takes a collection of requests (as in the first call), so
        # pass a one-element slice rather than a bare request object.
        self.db.DeleteMessageHandlerRequests(requests[:2])
        self.db.DeleteMessageHandlerRequests(requests[4:5])

        read = self.db.ReadMessageHandlerRequests()
        self.assertEqual(len(read), 2)
        read = sorted(read, key=lambda req: req.request_id)
        for r in read:
            r.timestamp = None

        self.assertEqual(requests[2:4], read)
Ejemplo n.º 2
0
    def Flush(self):
        """Writes the changes in this object to the datastore.

        If the relational DB is enabled for the "message_handlers" category,
        some queued responses are written to ``data_store.REL_DB`` as
        ``MessageHandlerRequest`` objects instead of the legacy store. These
        are responses with ``request_id == 0`` whose session id is a key of
        ``session_id_map``. All other requests and responses are stored via
        ``StoreRequestsAndResponses``. Client message deletions and new client
        messages go through a mutation pool. Queue notifications are added only
        after that pool's ``with`` block has exited, and are flushed
        explicitly. Finally, all pending state on this object is reset.
        """
        if data_store.RelationalDBReadEnabled(category="message_handlers"):
            handler_requests = []
            remaining = []

            for response, ts in self.response_queue:
                targets_handler = (response.request_id == 0 and
                                   response.session_id in session_id_map)
                if not targets_handler:
                    remaining.append((response, ts))
                    continue

                handler_requests.append(
                    rdf_objects.MessageHandlerRequest(
                        client_id=response.source and
                        response.source.Basename(),
                        handler_name=session_id_map[response.session_id],
                        request_id=response.response_id,
                        request=response.payload))

            if handler_requests:
                data_store.REL_DB.WriteMessageHandlerRequests(handler_requests)
            self.response_queue = remaining

        self.data_store.StoreRequestsAndResponses(
            new_requests=self.request_queue,
            new_responses=self.response_queue,
            requests_to_delete=self.requests_to_delete)

        # Notifications must be written after the requests they refer to. The
        # pool is therefore used for the request-side mutations first
        # (presumably flushed when the `with` block exits — confirm in the
        # mutation pool implementation). Only then are notifications added and
        # flushed explicitly.
        pool = self.data_store.GetMutationPool()
        with pool:
            pending_deletes = self.client_messages_to_delete
            for client_id, messages in pending_deletes.iteritems():
                self.Delete(client_id.Queue(), messages, mutation_pool=pool)

            if self.new_client_messages:
                by_timestamp = utils.GroupBy(self.new_client_messages,
                                             lambda x: x[1])
                for timestamp, batch in by_timestamp.iteritems():
                    self.Schedule([entry[0] for entry in batch],
                                  timestamp=timestamp,
                                  mutation_pool=pool)

        if self.notifications:
            for notification in self.notifications.itervalues():
                self.NotifyQueue(notification, mutation_pool=pool)

            pool.Flush()

        self.request_queue = []
        self.response_queue = []
        self.requests_to_delete = []

        self.client_messages_to_delete = {}
        self.notifications = {}
        self.new_client_messages = []
Ejemplo n.º 3
0
  def testMessageHandlerRequestSorting(self):
    """Checks that ReadMessageHandlerRequests returns the newest requests first.

    Requests are written one at a time under faked clock values that are not
    monotonic in write order. Neither insertion order nor ascending timestamp
    order satisfies the strictly-descending check below.
    """
    timestamps = [
        10000, 11000, 12000, 13000, 14000, 19000, 18000, 17000, 16000, 15000
    ]
    for i, ts in enumerate(timestamps):
      with test_lib.FakeTime(rdfvalue.RDFDatetime.FromSecondsSinceEpoch(ts)):
        request = rdf_objects.MessageHandlerRequest(
            client_id="C.1000000000000000",
            handler_name="Testhandler",
            request_id=i * 100,
            request=rdfvalue.RDFInteger(i))
        self.db.WriteMessageHandlerRequests([request])

    read = self.db.ReadMessageHandlerRequests()

    # Without this, an empty or single-element read would pass the ordering
    # check below vacuously.
    self.assertEqual(len(read), len(timestamps))

    # Each request must be strictly newer than the one that follows it.
    for newer, older in zip(read, read[1:]):
      self.assertGreater(newer.timestamp, older.timestamp)
Ejemplo n.º 4
0
    def PushToStateQueue(self, manager, message, **kw):
        """Push given message to the state queue.

        The message is first marked AUTHENTICATED, then every ``kw`` entry is
        set on it as an attribute. Messages with ``request_id == 0`` are
        dispatched immediately and never queued:

        - A MESSAGE-type message whose session id is in
          ``queue_manager.session_id_map`` goes to the mapped message handler.
        - Any other non-empty session id goes to the matching entry of
          ``self.well_known_flows``.
        - Anything else with ``request_id == 0`` is dropped.

        All other messages are passed to ``manager.QueueResponse``.

        Args:
          manager: Receives non-well-known messages via ``QueueResponse``.
          message: The GrrMessage to deliver; it is modified in place.
          **kw: Attribute values to set on ``message`` before dispatch.
        """
        # Assume the client is authorized
        message.auth_state = rdf_flows.GrrMessage.AuthorizationState.AUTHENTICATED

        # Update kw args
        for k, v in kw.items():
            setattr(message, k, v)

        # Handle well known flows
        if message.request_id == 0:

            # Well known flows only accept messages of type MESSAGE.
            if message.type == rdf_flows.GrrMessage.Type.MESSAGE:
                # Assume the message is authenticated and comes from this client.
                message.source = self.client_id

                # Set again because **kw above may have overridden auth_state.
                message.auth_state = "AUTHENTICATED"

                session_id = message.session_id
                if session_id:
                    handler_name = queue_manager.session_id_map.get(session_id)
                    if handler_name:
                        logging.info("Running message handler: %s",
                                     handler_name)
                        handler_cls = handler_registry.handler_name_map.get(
                            handler_name)
                        handler_request = rdf_objects.MessageHandlerRequest(
                            client_id=self.client_id.Basename(),
                            handler_name=handler_name,
                            request_id=message.response_id,
                            request=message.payload)

                        # NOTE(review): ProcessMessages (plural) processes a
                        # batch of requests, so pass the single request as a
                        # one-element list rather than the bare object.
                        handler_cls(
                            token=self.token).ProcessMessages([handler_request])
                    else:
                        logging.info("Running well known flow: %s", session_id)
                        self.well_known_flows[
                            session_id.FlowName()].ProcessMessage(message)

            return

        manager.QueueResponse(message)
Ejemplo n.º 5
0
    def testMessageHandlerRequestLeasing(self):
        """Exercises leasing of message handler requests and lease expiry.

        Ten requests are written, then leased in two batches of five at two
        fake times. A further lease attempt must find nothing free. Each leased
        request records its lease expiry and the leasing process id. Moving
        the clock past the first batch's expiry frees exactly five requests;
        moving it far past both frees all ten.
        """
        requests = [
            rdf_objects.MessageHandlerRequest(
                client_id="C.1000000000000000",
                handler_name="Testhandler",
                request_id=idx * 100,
                request=rdfvalue.RDFInteger(idx)) for idx in range(10)
        ]
        lease_time = rdfvalue.Duration("5m")

        def LeaseBatchAndCheck(now):
            """Leases up to five requests at `now`; returns their lease expiry."""
            expiry = now + lease_time
            batch = self.db.LeaseMessageHandlerRequests(lease_time=lease_time,
                                                        limit=5)
            self.assertEqual(len(batch), 5)
            for leased_request in batch:
                self.assertEqual(leased_request.leased_until, expiry)
                self.assertEqual(leased_request.leased_by,
                                 utils.ProcessIdString())
            return expiry

        with test_lib.FakeTime(
                rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000)):
            self.db.WriteMessageHandlerRequests(requests)

        base_secs = 100000

        t0 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs)
        with test_lib.FakeTime(t0):
            t0_expiry = LeaseBatchAndCheck(t0)

        t1 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 100)
        with test_lib.FakeTime(t1):
            t1_expiry = LeaseBatchAndCheck(t1)

            # Every request is leased by now, so nothing else can be taken.
            leftover = self.db.LeaseMessageHandlerRequests(
                lease_time=lease_time, limit=2)
            self.assertEqual(len(leftover), 0)

        stored = self.db.ReadMessageHandlerRequests()
        self.assertEqual(len(stored), 10)
        for r in stored:
            self.assertEqual(r.leased_by, utils.ProcessIdString())

        for expiry in (t0_expiry, t1_expiry):
            self.assertEqual(
                sum(1 for r in stored if r.leased_until == expiry), 5)

        # Half the leases expired: with a 5m lease, the first batch expires at
        # t0 + 300s and the second at t1 + 300s = t0 + 400s.
        with test_lib.FakeTime(
                rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 350)):
            released = self.db.LeaseMessageHandlerRequests(
                lease_time=lease_time)
            self.assertEqual(len(released), 5)

        # All of them expired.
        with test_lib.FakeTime(
                rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 10350)):
            released = self.db.LeaseMessageHandlerRequests(
                lease_time=lease_time)
            self.assertEqual(len(released), 10)