Esempio n. 1
0
    def WriteFlowRequests(self, requests):
        """Stores flow requests and schedules processing where it is due.

        All requests are checked against the known flows first; if any of them
        references an unknown flow, db.AtLeastOneUnknownFlowError is raised
        before anything is stored. Each request is stored as a copy stamped
        with the current time. When a stored request has needs_processing set
        and is the request its flow is waiting for, a FlowProcessingRequest is
        written for that flow.

        Args:
          requests: Iterable of flow requests. It is iterated twice, so it
            should be a sequence rather than a one-shot iterator.

        Raises:
          db.AtLeastOneUnknownFlowError: if a request references an unknown
            flow.
        """
        # Validation pass: reject the whole batch before mutating anything.
        for req in requests:
            flow_key = (req.client_id, req.flow_id)
            if flow_key not in self.flows:
                raise db.AtLeastOneUnknownFlowError([flow_key])

        to_process = []
        for req in requests:
            flow_key = (req.client_id, req.flow_id)
            per_flow = self.flow_requests.setdefault(flow_key, {})
            stored = req.Copy()
            per_flow[req.request_id] = stored
            stored.timestamp = rdfvalue.RDFDatetime.Now()

            # Only the request the flow is currently waiting for triggers
            # processing.
            if (req.needs_processing and
                    self.flows[flow_key].next_request_to_process ==
                    req.request_id):
                to_process.append(
                    rdf_flows.FlowProcessingRequest(client_id=req.client_id,
                                                    flow_id=req.flow_id))

        if to_process:
            self.WriteFlowProcessingRequests(to_process)
Esempio n. 2
0
    def testFlowProcessingRequestsQueue(self):
        """Written processing requests are delivered to the registered handler.

        The handler acknowledges each request and forwards it to a queue. The
        test waits for all five requests and compares them, sorted by flow id,
        with what was written.
        """
        client_id, _ = self._SetupClientAndFlow()
        flow_ids = [u"1234ABC%d" % i for i in range(5)]

        queue = Queue.Queue()

        def Callback(request):
            self.db.AckFlowProcessingRequests([request])
            queue.put(request)

        self.db.RegisterFlowProcessingHandler(Callback)
        # Unregister in a finally clause so a failed assertion or a timeout does
        # not leave the handler registered for subsequent tests.
        try:
            requests = [
                rdf_flows.FlowProcessingRequest(client_id=client_id,
                                                flow_id=flow_id)
                for flow_id in flow_ids
            ]

            self.db.WriteFlowProcessingRequests(requests)

            got = []
            while len(got) < 5:
                try:
                    received = queue.get(True, timeout=6)
                except Queue.Empty:
                    self.fail(
                        "Timed out waiting for messages, expected 5, got %d" %
                        len(got))
                got.append(received)

            # Delivery order is not guaranteed; flow_ids are generated in
            # sorted order, so sorting makes the comparison order-independent.
            got.sort(key=lambda req: req.flow_id)
            self.assertEqual(requests, got)
        finally:
            self.db.UnregisterFlowProcessingHandler()
Esempio n. 3
0
    def testAcknowledgingFlowProcessingRequestsWorks(self):
        """Acked or bulk-deleted processing requests disappear from the db."""
        client_id, _ = self._SetupClientAndFlow()
        flow_ids = [u"1234ABC%d" % i for i in range(5)]

        # Scheduled ten minutes ahead; presumably so nothing gets delivered
        # while the test inspects the table -- TODO confirm.
        delivery_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration("10m")
        self.db.WriteFlowProcessingRequests([
            rdf_flows.FlowProcessingRequest(client_id=client_id,
                                            flow_id=flow_id,
                                            delivery_time=delivery_time)
            for flow_id in flow_ids
        ])

        # All five requests must be readable back.
        stored = sorted(self.db.ReadFlowProcessingRequests(),
                        key=lambda r: r.flow_id)
        self.assertEqual(len(stored), 5)
        self.assertEqual([r.flow_id for r in stored], flow_ids)

        # Acknowledging the second and third request leaves the other three.
        self.db.AckFlowProcessingRequests(stored[1:3])
        remaining = self.db.ReadFlowProcessingRequests()
        self.assertEqual(len(remaining), 3)
        self.assertItemsEqual([r.flow_id for r in remaining],
                              [flow_ids[idx] for idx in (0, 3, 4)])

        # Bulk deletion must empty the table.
        self.db.DeleteAllFlowProcessingRequests()
        self.assertEqual(self.db.ReadFlowProcessingRequests(), [])

        # NOTE(review): no handler is registered in this test; presumably a
        # defensive cleanup -- confirm it is a no-op in every backend.
        self.db.UnregisterFlowProcessingHandler()
Esempio n. 4
0
    def testFlowProcessingRequestsQueueWithDelay(self):
        """Delayed processing requests are delivered no earlier than scheduled.

        Five requests with a common delivery time are written. Each one
        received by the handler must arrive after its delivery_time, and all
        five must arrive.
        """
        client_id, flow_id = self._SetupClientAndFlow()

        queue = Queue.Queue()
        self.db.RegisterFlowProcessingHandler(queue.put)
        # Unregister in a finally clause so a failed assertion or a timeout does
        # not leave the handler registered for subsequent tests.
        try:
            now = rdfvalue.RDFDatetime.Now()
            # NOTE(review): if AsSecondsSinceEpoch() truncates to whole seconds,
            # this delivery time can lie up to 0.5s in the past, which weakens
            # the delay check -- confirm against rdfvalue.RDFDatetime.
            delivery_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(
                now.AsSecondsSinceEpoch() + 0.5)
            requests = [
                rdf_flows.FlowProcessingRequest(client_id=client_id,
                                                flow_id=flow_id,
                                                request_id=i,
                                                delivery_time=delivery_time)
                for i in xrange(5)
            ]

            self.db.WriteFlowProcessingRequests(requests)

            got = []
            while len(got) < 5:
                try:
                    received = queue.get(True, timeout=6)
                except Queue.Empty:
                    self.fail(
                        "Timed out waiting for messages, expected 5, got %d" %
                        len(got))
                got.append(received)
                # A request must never be handed out before its delivery time.
                self.assertGreater(rdfvalue.RDFDatetime.Now(),
                                   received.delivery_time)

            got.sort(key=lambda req: req.request_id)
            self.assertEqual(requests, got)
        finally:
            self.db.UnregisterFlowProcessingHandler()
Esempio n. 5
0
    def testFlowProcessingRequestsQueue(self):
        """Written processing requests are delivered to the registered handler.

        Five requests for the same flow, distinguished by request_id, are
        written. The test waits for all of them to reach the handler's queue.
        """
        client_id, flow_id = self._SetupClientAndFlow()

        queue = Queue.Queue()
        self.db.RegisterFlowProcessingHandler(queue.put)
        # Unregister in a finally clause so a failed assertion or a timeout does
        # not leave the handler registered for subsequent tests.
        try:
            requests = [
                rdf_flows.FlowProcessingRequest(client_id=client_id,
                                                flow_id=flow_id,
                                                request_id=i)
                for i in xrange(5)
            ]

            self.db.WriteFlowProcessingRequests(requests)

            got = []
            while len(got) < 5:
                try:
                    received = queue.get(True, timeout=6)
                except Queue.Empty:
                    self.fail(
                        "Timed out waiting for messages, expected 5, got %d" %
                        len(got))
                got.append(received)

            # Delivery order is not guaranteed; compare in request_id order.
            got.sort(key=lambda req: req.request_id)
            self.assertEqual(requests, got)
        finally:
            self.db.UnregisterFlowProcessingHandler()
Esempio n. 6
0
    def WriteFlowRequests(self, requests, cursor=None):
        """Writes a list of flow requests to the database.

        Besides inserting rows into flow_requests, this writes a
        FlowProcessingRequest for every flow whose next_request_to_process
        matches one of the written requests that has needs_processing set.

        Args:
          requests: Iterable of flow requests to write. An empty input is a
            no-op.
          cursor: MySQL cursor used for all queries.

        Raises:
          db.AtLeastOneUnknownFlowError: if the INSERT raises an integrity
            error; all written (client_id, flow_id) pairs are reported.
        """
        # Materialize so generators also work with the emptiness check below.
        requests = list(requests)
        if not requests:
            # Nothing to write. This also avoids issuing an INSERT with an empty
            # VALUES list, which is not valid SQL.
            return

        args = []
        templates = []
        flow_keys = []
        # (client_id, flow_id) -> request ids in this batch that have
        # needs_processing set.
        needs_processing = {}
        now_str = mysql_utils.RDFDatetimeToMysqlString(
            rdfvalue.RDFDatetime.Now())
        for r in requests:
            if r.needs_processing:
                needs_processing.setdefault((r.client_id, r.flow_id),
                                            []).append(r.request_id)

            flow_keys.append((r.client_id, r.flow_id))
            templates.append("(%s, %s, %s, %s, %s, %s)")
            args.extend([
                mysql_utils.ClientIDToInt(r.client_id),
                mysql_utils.FlowIDToInt(r.flow_id), r.request_id,
                r.needs_processing,
                r.SerializeToString(), now_str
            ])

        if needs_processing:
            flow_processing_requests = []
            nr_conditions = []
            nr_args = []
            for client_id, flow_id in needs_processing:
                nr_conditions.append("(client_id=%s AND flow_id=%s)")
                nr_args.append(mysql_utils.ClientIDToInt(client_id))
                nr_args.append(mysql_utils.FlowIDToInt(flow_id))

            # Fetch the request each affected flow is waiting for; only flows
            # whose awaited request was just written as ready get scheduled.
            nr_query = ("SELECT client_id, flow_id, next_request_to_process "
                        "FROM flows WHERE ")
            nr_query += " OR ".join(nr_conditions)

            cursor.execute(nr_query, nr_args)

            db_result = cursor.fetchall()
            for client_id_int, flow_id_int, next_request_to_process in db_result:
                client_id = mysql_utils.IntToClientID(client_id_int)
                flow_id = mysql_utils.IntToFlowID(flow_id_int)
                if next_request_to_process in needs_processing[(client_id,
                                                                flow_id)]:
                    flow_processing_requests.append(
                        rdf_flows.FlowProcessingRequest(client_id=client_id,
                                                        flow_id=flow_id))

            if flow_processing_requests:
                self._WriteFlowProcessingRequests(flow_processing_requests,
                                                  cursor)

        query = ("INSERT INTO flow_requests "
                 "(client_id, flow_id, request_id, needs_processing, request, "
                 "timestamp) VALUES ")
        query += ", ".join(templates)
        try:
            cursor.execute(query, args)
        except MySQLdb.IntegrityError as e:
            # Presumably flow_requests has a foreign key on flows, so an
            # integrity error is reported as an unknown flow -- confirm
            # against the schema.
            raise db.AtLeastOneUnknownFlowError(flow_keys, cause=e)
Esempio n. 7
0
 def testRaisesIfFlowProcessingRequestDoesNotTriggerAnyProcessing(self):
     """ProcessFlow must reject a request when the flow has nothing to do.

     A CallClientParentFlow is started and a processing request for it is
     handed to the worker. The worker is expected to raise
     FlowHasNothingToProcessError.
     """
     with flow_test_lib.TestWorker() as worker:
         started_flow_id = flow.StartFlow(
             flow_cls=CallClientParentFlow, client_id=self.client_id)
         processing_request = rdf_flows.FlowProcessingRequest(
             client_id=self.client_id, flow_id=started_flow_id)
         self.assertRaises(worker_lib.FlowHasNothingToProcessError,
                           worker.ProcessFlow, processing_request)
Esempio n. 8
0
    def WriteFlowResponses(self, responses):
        """Writes a list of flow responses to the in-memory store.

        Responses for unknown requests are logged and dropped. A FlowStatus
        records how many responses its request expects. Once a request holds
        exactly that many responses it is flagged needs_processing. If it is
        also the request its flow is waiting for, a FlowProcessingRequest is
        written.

        Args:
          responses: Iterable of flow responses / statuses to write.

        Raises:
          db.UnknownFlowError: if any response references an unknown flow. The
            check runs before anything is stored, so a failing call leaves the
            store unchanged.
        """
        # Materialize: the input is iterated twice below.
        responses = list(responses)

        # Validate every flow up front, as WriteFlowRequests does. A bad
        # response late in the batch must not leave earlier responses stored
        # but never status/processing-updated.
        for response in responses:
            if (response.client_id, response.flow_id) not in self.flows:
                raise db.UnknownFlowError(response.client_id, response.flow_id)

        status_available = set()
        requests_updated = set()

        for response in responses:
            flow_key = (response.client_id, response.flow_id)

            request_dict = self.flow_requests.get(flow_key, {})
            if response.request_id not in request_dict:
                logging.error(
                    "Received response for unknown request %s, %s, %d.",
                    response.client_id, response.flow_id, response.request_id)
                continue

            response_dict = self.flow_responses.setdefault(flow_key, {})
            response_dict.setdefault(
                response.request_id,
                {})[response.response_id] = response.Copy()

            if isinstance(response, rdf_flow_objects.FlowStatus):
                status_available.add(response)

            requests_updated.add(
                (response.client_id, response.flow_id, response.request_id))

        # Every time we get a status we store how many responses are expected.
        for status in status_available:
            request_dict = self.flow_requests[(status.client_id,
                                               status.flow_id)]
            request = request_dict[status.request_id]
            request.nr_responses_expected = status.response_id

        # And we check for all updated requests if we need to process them.
        needs_processing = []

        for client_id, flow_id, request_id in requests_updated:
            flow_key = (client_id, flow_id)
            request_dict = self.flow_requests[flow_key]
            request = request_dict[request_id]
            if request.nr_responses_expected and not request.needs_processing:
                response_dict = self.flow_responses.setdefault(flow_key, {})
                # Named so it does not shadow the `responses` argument.
                stored_responses = response_dict.get(request_id, {})

                if len(stored_responses) == request.nr_responses_expected:
                    request.needs_processing = True
                    flow = self.flows[flow_key]
                    # Only the request the flow is waiting for triggers
                    # processing.
                    if flow.next_request_to_process == request_id:
                        needs_processing.append(
                            rdf_flows.FlowProcessingRequest(
                                client_id=client_id,
                                flow_id=flow_id,
                                request_id=request_id))

        if needs_processing:
            self.WriteFlowProcessingRequests(needs_processing)
Esempio n. 9
0
    def testFlowProcessingRequestDeletion(self):
        """DeleteFlowProcessingRequests removes exactly the given requests."""
        client_id, flow_id = self._SetupClientAndFlow()
        delivery_time = rdfvalue.RDFDatetime.Now() + rdfvalue.Duration("10m")
        requests = [
            rdf_flows.FlowProcessingRequest(client_id=client_id,
                                            flow_id=flow_id,
                                            request_id=i,
                                            delivery_time=delivery_time)
            for i in xrange(5)
        ]

        self.db.WriteFlowProcessingRequests(requests)

        # All five requests are stored.
        stored_ids = [
            r.request_id for r in self.db.ReadFlowProcessingRequests()
        ]
        self.assertEqual(len(stored_ids), 5)
        self.assertItemsEqual(stored_ids, [0, 1, 2, 3, 4])

        # Deleting requests 1 and 2 leaves 0, 3 and 4.
        self.db.DeleteFlowProcessingRequests(requests[1:3])
        remaining_ids = [
            r.request_id for r in self.db.ReadFlowProcessingRequests()
        ]
        self.assertEqual(len(remaining_ids), 3)
        self.assertItemsEqual(remaining_ids, [0, 3, 4])
Esempio n. 10
0
  def WriteFlowResponses(self, responses, cursor=None):
    """Writes a list of flow responses to the database.

    Besides inserting the responses, this also:
    - passes the expected response count of every request that received a
      FlowStatus in this batch to _UpdateRequests,
    - passes every request that now has all of its responses to
      _UpdateRequests so its needs_processing flag gets set,
    - removes the client messages (by task_id) of requests that just
      completed,
    - writes a FlowProcessingRequest for each flow whose next request (as read
      by _ReadCurrentFlowInfo) is one of the just-completed requests.

    Responses for requests that are not currently in the database are logged
    and dropped.

    Args:
      responses: Sequence of flow responses / statuses. It is iterated more
        than once, so it must not be a one-shot iterator. An empty sequence is
        a no-op.
      cursor: MySQL cursor passed through to the helper methods.

    Raises:
      ValueError: if a FlowStatus disagrees with the expected response count
        already known for its request (from the database or from an earlier
        status in the same batch).
    """

    if not responses:
      return

    # In addition to writing the responses themselves, this function also
    # needs to:
    # - Update the expected number of responses for each request that received
    #   a status.
    # - Set the needs_processing flag for all requests that now have all
    #   responses.
    # - Send FlowProcessingRequests for all flows that are waiting on a request
    #   whose needs_processing flag was just set.

    # To achieve this, we need to get the next request each flow is waiting for.
    next_request_by_flow = {}

    # And the number of responses each affected request is waiting for (if
    # available).
    responses_expected_by_request = {}

    # As well as the ids of the currently available responses for each request.
    current_responses_by_request = {}

    # We also collect all requests present in the db so we can discard
    # responses for unknown requests right away.
    currently_available_requests = set()

    # Fills the four containers above in place (they are passed as out
    # parameters).
    self._ReadCurrentFlowInfo(
        responses, currently_available_requests, next_request_by_flow,
        responses_expected_by_request, current_responses_by_request, cursor)

    # For some requests we will need to update the number of expected responses.
    needs_expected_update = {}

    # For some we will need to update the needs_processing flag.
    needs_processing_update = set()

    # Some completed requests will trigger a flow processing request, we collect
    # them in:
    flow_processing_requests = []

    task_ids_by_request = {}

    # First pass: collect task ids and merge status information into
    # responses_expected_by_request.
    for r in responses:
      request_key = (r.client_id, r.flow_id, r.request_id)

      try:
        # If this is a response coming from a client, a task_id will be set. We
        # store it in case the request is complete and we can remove the client
        # messages.
        task_ids_by_request[request_key] = r.task_id
      except AttributeError:
        pass

      if not isinstance(r, rdf_flow_objects.FlowStatus):
        continue

      current = responses_expected_by_request.get(request_key)
      # A falsy value (missing/None/0) means no expected count is known yet.
      if current:
        logging.error("Got duplicate status message for request %s/%s/%d",
                      r.client_id, r.flow_id, r.request_id)
        # If there is already responses_expected information, we need to make
        # sure the current status doesn't disagree.
        if current != r.response_id:
          raise ValueError(
              "Got conflicting status information for request %s: %s" %
              (request_key, r))
      else:
        needs_expected_update[request_key] = r.response_id

      # Record the count so the completion check below (and later statuses in
      # this batch) can see it.
      responses_expected_by_request[request_key] = r.response_id

    # Second pass: pick the responses to write and detect completed requests.
    responses_to_write = []
    client_messages_to_delete = []
    for r in responses:
      request_key = (r.client_id, r.flow_id, r.request_id)

      if request_key not in currently_available_requests:
        logging.info("Dropping response for unknown request %s/%s/%d",
                     r.client_id, r.flow_id, r.request_id)
        continue

      # NOTE(review): this happens before the duplicate check below, so
      # already-known responses are passed to _WriteResponses again;
      # presumably it tolerates existing rows -- confirm.
      responses_to_write.append(r)

      current_responses = current_responses_by_request.setdefault(
          request_key, set())
      if r.response_id in current_responses:
        # We have this response already, nothing further to do.
        continue

      current_responses.add(r.response_id)
      # With no known expected count the default 0 can never match, because
      # current_responses is non-empty at this point.
      expected_responses = responses_expected_by_request.get(request_key, 0)
      if len(current_responses) == expected_responses:
        # This response was the one that was missing, time to set the
        # needs_processing flag.
        needs_processing_update.add(request_key)
        # Assumes _ReadCurrentFlowInfo added an entry for every flow that owns
        # a currently available request; otherwise this raises KeyError --
        # TODO confirm.
        if r.request_id == next_request_by_flow[(r.client_id, r.flow_id)]:
          # The request that is now ready for processing was also the one the
          # flow was waiting for.
          req = rdf_flows.FlowProcessingRequest(
              client_id=r.client_id, flow_id=r.flow_id)
          flow_processing_requests.append(req)

        # Since this request is now complete, we can remove the corresponding
        # client messages if there are any.
        task_id = task_ids_by_request.get(request_key, None)
        if task_id is not None:
          client_messages_to_delete.append((r.client_id, task_id))

    now = rdfvalue.RDFDatetime.Now()
    now_str = mysql_utils.RDFDatetimeToMysqlString(now)

    if responses_to_write:
      self._WriteResponses(responses_to_write, now_str, cursor)

    # Called unconditionally; both collections may be empty.
    self._UpdateRequests(needs_processing_update, needs_expected_update, cursor)

    if client_messages_to_delete:
      self._DeleteClientMessages(client_messages_to_delete, cursor)

    if flow_processing_requests:
      self._WriteFlowProcessingRequests(flow_processing_requests, cursor)
Esempio n. 11
0
    def WriteFlowResponses(self, responses):
        """Stores flow responses and triggers processing for completed requests.

        Responses for unknown flows or unknown requests are logged and
        skipped. Each stored response is a copy stamped with the current time.
        A FlowStatus sets the number of responses its request expects. When a
        request then holds exactly that many responses:
        - it is flagged needs_processing,
        - any client message recorded under its task_id is deleted,
        - if the flow is waiting on that request, a FlowProcessingRequest is
          written.

        Args:
          responses: Iterable of flow responses / statuses to store.
        """
        statuses = set()
        touched = set()
        task_id_for = {}

        for resp in responses:
            flow_key = (resp.client_id, resp.flow_id)
            if flow_key not in self.flows:
                logging.error("Received response for unknown flow %s, %s.",
                              resp.client_id, resp.flow_id)
                continue

            if resp.request_id not in self.flow_requests.get(flow_key, {}):
                logging.error(
                    "Received response for unknown request %s, %s, %d.",
                    resp.client_id, resp.flow_id, resp.request_id)
                continue

            by_request = self.flow_responses.setdefault(flow_key, {})
            stored = resp.Copy()
            stored.timestamp = rdfvalue.RDFDatetime.Now()
            by_request.setdefault(resp.request_id,
                                  {})[resp.response_id] = stored

            if isinstance(resp, rdf_flow_objects.FlowStatus):
                statuses.add(resp)

            request_key = (resp.client_id, resp.flow_id, resp.request_id)
            touched.add(request_key)
            # Only client-originated responses carry a task_id.
            try:
                task_id_for[request_key] = resp.task_id
            except AttributeError:
                pass

        # A status tells us how many responses its request should end up with.
        for status in statuses:
            status_request = self.flow_requests[(
                status.client_id, status.flow_id)][status.request_id]
            status_request.nr_responses_expected = status.response_id

        ready = []
        for client_id, flow_id, request_id in touched:
            flow_key = (client_id, flow_id)
            request = self.flow_requests[flow_key][request_id]
            # Skip requests with no known expected count or already flagged.
            if not request.nr_responses_expected or request.needs_processing:
                continue

            received = self.flow_responses.setdefault(flow_key,
                                                      {}).get(request_id, {})
            if len(received) == request.nr_responses_expected:
                request.needs_processing = True

                task_id = task_id_for.get((client_id, flow_id, request_id))
                if task_id:
                    self._DeleteClientMessage(client_id, task_id)

                if self.flows[flow_key].next_request_to_process == request_id:
                    ready.append(
                        rdf_flows.FlowProcessingRequest(client_id=client_id,
                                                        flow_id=flow_id))

        if ready:
            self.WriteFlowProcessingRequests(ready)