Beispiel #1
0
    def testReadFlowForProcessingUpdatesFlowObjects(self):
        """Leasing a flow stamps processing fields; returning it clears them.

        Both the object returned by the database call and a freshly read copy
        of the stored flow are checked, once after ReadFlowForProcessing and
        once after ReturnProcessedFlow.
        """
        client_id, flow_id = self._SetupClientAndFlow()

        lease_start = rdfvalue.RDFDatetime.Now()
        lease_duration = rdfvalue.Duration("60s")
        expected_deadline = lease_start + lease_duration

        with test_lib.FakeTime(lease_start):
            leased_flow = self.db.ReadFlowForProcessing(client_id, flow_id,
                                                        lease_duration)

        def _AssertProcessingFieldsSet(flow_obj):
            self.assertEqual(flow_obj.processing_on, utils.ProcessIdString())
            self.assertEqual(flow_obj.processing_since, lease_start)
            self.assertEqual(flow_obj.processing_deadline, expected_deadline)

        def _AssertProcessingFieldsCleared(flow_obj):
            self.assertFalse(flow_obj.processing_on)
            self.assertIsNone(flow_obj.processing_since)
            self.assertIsNone(flow_obj.processing_deadline)

        # The returned object and the persisted flow both carry the lease.
        _AssertProcessingFieldsSet(leased_flow)
        _AssertProcessingFieldsSet(self.db.ReadFlowObject(client_id, flow_id))

        # Returning the flow must persist other changes and drop the lease.
        leased_flow.next_request_to_process = 5
        self.db.ReturnProcessedFlow(leased_flow)
        _AssertProcessingFieldsCleared(leased_flow)

        stored_flow = self.db.ReadFlowObject(client_id, flow_id)
        _AssertProcessingFieldsCleared(stored_flow)
        self.assertEqual(stored_flow.next_request_to_process, 5)
Beispiel #2
0
    def LeaseClientMessages(self,
                            client_id,
                            lease_time=None,
                            limit=sys.maxsize):
        """Leases available client messages for the client with the given id.

        A message can be leased when it has no recorded lease or its recorded
        lease expired before now. Each newly leased message gets its
        `leased_until` and `leased_by` fields set in place, and the lease is
        stored in `self.client_message_leases` under the message's task_id.

        Args:
          client_id: Id of the client whose messages should be leased.
          lease_time: Duration added to the current time to obtain the lease
            expiration. Must be supplied; the None default is not handled here.
          limit: Maximum number of messages to lease in this call.

        Returns:
          A list with at most `limit` leased messages.
        """
        leased_messages = []

        now = rdfvalue.RDFDatetime.Now()
        expiration_time = now + lease_time
        process_id_str = utils.ProcessIdString()

        leases = self.client_message_leases
        for msgs_by_id in itervalues(self.client_messages):
            for msg in sorted(itervalues(msgs_by_id), key=lambda m: m.task_id):
                if db_utils.ClientIdFromGrrMessage(msg) != client_id:
                    continue

                existing_lease = leases.get(msg.task_id)
                if not existing_lease or existing_lease[0] < now:
                    # Check before leasing and return: a `break` would only
                    # leave the inner loop, letting later groups go past the
                    # limit, and would still lease one message for limit <= 0.
                    if len(leased_messages) >= limit:
                        return leased_messages
                    leases[msg.task_id] = (expiration_time, process_id_str)
                    msg.leased_until = expiration_time
                    msg.leased_by = process_id_str
                    leased_messages.append(msg)

        return leased_messages
Beispiel #3
0
  def _LeaseMessageHandlerRequests(self, lease_time, limit, cursor=None):
    """Claims up to `limit` unleased message handler requests for this process.

    A request is claimable when its lease is unset or already in the past.
    The rows are claimed with a single UPDATE and then fetched again by
    matching the lease owner and lease expiry that were just written.

    Args:
      lease_time: Duration added to the current time to form the lease expiry.
      limit: Maximum number of rows the UPDATE may claim.
      cursor: Database cursor used for both statements.

    Returns:
      A list of MessageHandlerRequest objects whose timestamp, leased_until
      and leased_by fields are filled in; empty if nothing was claimed.
    """
    current_time = rdfvalue.RDFDatetime.Now()
    current_time_str = mysql_utils.RDFDatetimeToMysqlString(current_time)
    lease_expiry = current_time + lease_time
    lease_expiry_str = mysql_utils.RDFDatetimeToMysqlString(lease_expiry)
    lease_owner = utils.ProcessIdString()

    num_claimed = cursor.execute(
        "UPDATE message_handler_requests "
        "SET leased_until=%s, leased_by=%s "
        "WHERE leased_until IS NULL OR leased_until < %s "
        "LIMIT %s", (lease_expiry_str, lease_owner, current_time_str, limit))
    if num_claimed == 0:
      return []

    # Read back exactly the rows stamped by the UPDATE above.
    cursor.execute(
        "SELECT timestamp, request FROM message_handler_requests "
        "WHERE leased_by=%s AND leased_until=%s LIMIT %s",
        (lease_owner, lease_expiry_str, num_claimed))

    claimed = []
    for raw_timestamp, serialized in cursor.fetchall():
      handler_request = rdf_objects.MessageHandlerRequest.FromSerializedString(
          serialized)
      handler_request.timestamp = mysql_utils.MysqlToRDFDatetime(raw_timestamp)
      handler_request.leased_until = lease_expiry
      handler_request.leased_by = lease_owner
      claimed.append(handler_request)
    return claimed
Beispiel #4
0
  def ReadFlowForProcessing(self, client_id, flow_id, processing_time):
    """Marks a flow as being processed on this worker and returns it.

    Args:
      client_id: Id of the client the flow belongs to.
      flow_id: Id of the flow to mark.
      processing_time: Duration added to the current time to form the
        processing deadline.

    Returns:
      The flow object with processing_on, processing_since and
      processing_deadline set to the values written via UpdateFlow.

    Raises:
      db.ParentHuntIsNotRunningError: if the flow has a non-legacy parent hunt
        whose state is not suitable for flow processing.
      ValueError: if the flow is already marked as processed and its
        processing deadline has not yet passed.
    """
    rdf_flow = self.ReadFlowObject(client_id, flow_id)
    # TODO(user): remove the check for a legacy hunt prefix as soon as
    # AFF4 is gone.
    if rdf_flow.parent_hunt_id and not rdf_flow.parent_hunt_id.startswith("H:"):
      rdf_hunt = self.ReadHuntObject(rdf_flow.parent_hunt_id)
      if not rdf_hunt_objects.IsHuntSuitableForFlowProcessing(
          rdf_hunt.hunt_state):
        raise db.ParentHuntIsNotRunningError(
            client_id, flow_id, rdf_hunt.hunt_id, rdf_hunt.hunt_state)

    now = rdfvalue.RDFDatetime.Now()
    if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
      # The message names the flow first, then the client.
      raise ValueError("Flow %s on client %s is already being processed." %
                       (flow_id, client_id))
    processing_deadline = now + processing_time
    process_id_string = utils.ProcessIdString()
    self.UpdateFlow(
        client_id,
        flow_id,
        processing_on=process_id_string,
        processing_since=now,
        processing_deadline=processing_deadline)
    # Mirror the stored values on the returned object.
    rdf_flow.processing_on = process_id_string
    rdf_flow.processing_since = now
    rdf_flow.processing_deadline = processing_deadline
    return rdf_flow
Beispiel #5
0
  def _LeaseFlowProcessingReqests(self, cursor=None, lease_time=None, limit=50):
    """Leases a number of flow processing requests.

    Claims requests whose delivery time is unset or has arrived and whose
    lease is unset or expired. Claiming is done with a single UPDATE, and the
    claimed rows are then read back by matching lease owner and expiry.

    Args:
      cursor: Database cursor used for both statements.
      lease_time: Optional rdfvalue.Duration for the lease; defaults to 10
        minutes when None.
      limit: Maximum number of requests to lease in one call (default 50).

    Returns:
      A list of FlowProcessingRequest objects with timestamp, leased_until and
      leased_by filled in; empty if nothing was claimed.
    """
    if lease_time is None:
      lease_time = rdfvalue.Duration("10m")

    now = rdfvalue.RDFDatetime.Now()
    now_str = mysql_utils.RDFDatetimeToMysqlString(now)

    expiry = now + lease_time
    expiry_str = mysql_utils.RDFDatetimeToMysqlString(expiry)

    query = ("UPDATE flow_processing_requests "
             "SET leased_until=%s, leased_by=%s "
             "WHERE (delivery_time IS NULL OR delivery_time <= %s) AND "
             "(leased_until IS NULL OR leased_until < %s) "
             "LIMIT %s")

    id_str = utils.ProcessIdString()
    args = (expiry_str, id_str, now_str, now_str, limit)
    updated = cursor.execute(query, args)

    if updated == 0:
      return []

    # Read back exactly the rows stamped by the UPDATE above.
    cursor.execute(
        "SELECT timestamp, request FROM flow_processing_requests "
        "WHERE leased_by=%s AND leased_until=%s LIMIT %s",
        (id_str, expiry_str, updated))
    res = []
    for timestamp, request in cursor.fetchall():
      req = rdf_flows.FlowProcessingRequest.FromSerializedString(request)
      req.timestamp = mysql_utils.MysqlToRDFDatetime(timestamp)
      req.leased_until = expiry
      req.leased_by = id_str
      res.append(req)

    return res
Beispiel #6
0
    def LeaseCronJobs(self, cronjob_ids=None, lease_time=None, cursor=None):
        """Leases every cron job whose lease is unset or already expired.

        Args:
          cronjob_ids: Optional collection of job ids; when non-empty only
            these jobs are considered, otherwise all jobs are.
          lease_time: Duration added to the current time to form the expiry.
          cursor: Database cursor used for both statements.

        Returns:
          The leased cron jobs as built by `self._CronJobFromRow`; empty if
          the UPDATE claimed nothing.
        """
        current_time = rdfvalue.RDFDatetime.Now()
        current_ts = mysql_utils.RDFDatetimeToTimestamp(current_time)
        expiry_ts = mysql_utils.RDFDatetimeToTimestamp(current_time +
                                                       lease_time)
        lease_owner = utils.ProcessIdString()

        where_clauses = [
            "(leased_until IS NULL OR leased_until < FROM_UNIXTIME(%s))"
        ]
        params = [expiry_ts, lease_owner, current_ts]
        if cronjob_ids:
            placeholders = ", ".join(["%s"] * len(cronjob_ids))
            where_clauses.append("job_id in (%s)" % placeholders)
            params.extend(cronjob_ids)

        update_sql = ("UPDATE cron_jobs "
                      "SET leased_until=FROM_UNIXTIME(%s), leased_by=%s "
                      "WHERE " + " AND ".join(where_clauses))
        if cursor.execute(update_sql, params) == 0:
            return []

        # Fetch the rows that now carry this process's lease.
        select_sql = (
            "SELECT job, UNIX_TIMESTAMP(create_time), enabled,"
            "forced_run_requested, last_run_status, UNIX_TIMESTAMP(last_run_time), "
            "current_run_id, state, UNIX_TIMESTAMP(leased_until), leased_by "
            "FROM cron_jobs "
            "FORCE INDEX (cron_jobs_by_lease) "
            "WHERE leased_until=FROM_UNIXTIME(%s) AND leased_by=%s")
        cursor.execute(select_sql, [expiry_ts, lease_owner])
        return list(map(self._CronJobFromRow, cursor.fetchall()))
Beispiel #7
0
  def LeaseFlowForProcessing(self, client_id, flow_id, processing_time):
    """Claims a flow for processing by this process and returns it.

    Args:
      client_id: Id of the client the flow belongs to.
      flow_id: Id of the flow to claim.
      processing_time: Duration added to the current time to form the
        processing deadline.

    Returns:
      The flow object with processing_on, processing_since and
      processing_deadline set to the values written via UpdateFlow.

    Raises:
      db.ParentHuntIsNotRunningError: if the flow's parent hunt is in a state
        not suitable for flow processing.
      ValueError: if the flow is already marked as processed and its deadline
        has not passed.
    """
    flow_obj = self.ReadFlowObject(client_id, flow_id)

    parent_hunt_id = flow_obj.parent_hunt_id
    if parent_hunt_id:
      hunt_obj = self.ReadHuntObject(parent_hunt_id)
      if not rdf_hunt_objects.IsHuntSuitableForFlowProcessing(
          hunt_obj.hunt_state):
        raise db.ParentHuntIsNotRunningError(client_id, flow_id,
                                             hunt_obj.hunt_id,
                                             hunt_obj.hunt_state)

    lease_start = rdfvalue.RDFDatetime.Now()
    still_leased = (flow_obj.processing_on and
                    flow_obj.processing_deadline > lease_start)
    if still_leased:
      raise ValueError("Flow %s on client %s is already being processed." %
                       (flow_id, client_id))

    deadline = lease_start + processing_time
    worker_id = utils.ProcessIdString()
    self.UpdateFlow(
        client_id,
        flow_id,
        processing_on=worker_id,
        processing_since=lease_start,
        processing_deadline=deadline)

    # Reflect the persisted lease on the object handed back to the caller.
    flow_obj.processing_on = worker_id
    flow_obj.processing_since = lease_start
    flow_obj.processing_deadline = deadline
    return flow_obj
Beispiel #8
0
  def LeaseClientMessages(self, client_id, lease_time=None, limit=sys.maxsize):
    """Leases available client messages for the client with the given id.

    A message is leasable when it has no recorded lease or its lease expired
    before now. Lease entries are (expiration, process id, lease count)
    tuples. A message whose next lease count would reach
    db.Database.CLIENT_MESSAGES_TTL is deleted instead of leased.

    Args:
      client_id: Id of the client whose messages should be leased.
      lease_time: Duration added to the current time to obtain the lease
        expiration. Must be supplied; the None default is not handled here.
      limit: Maximum number of messages to lease in this call.

    Returns:
      A list with at most `limit` leased messages.
    """
    leased_messages = []

    now = rdfvalue.RDFDatetime.Now()
    expiration_time = now + lease_time
    process_id_str = utils.ProcessIdString()

    leases = self.client_message_leases
    # Iterate a snapshot: _DeleteClientMessage below may modify the message
    # storage, and a dict that changes size during iteration raises.
    for msgs_by_id in list(itervalues(self.client_messages)):
      # sorted() already copies, so deletions from msgs_by_id are safe.
      for msg in sorted(itervalues(msgs_by_id), key=lambda m: m.task_id):
        if db_utils.ClientIdFromGrrMessage(msg) != client_id:
          continue

        existing_lease = leases.get(msg.task_id)
        if not existing_lease or existing_lease[0] < now:
          # Return rather than `break`: a break only leaves the inner loop,
          # letting later groups exceed the limit (and limit <= 0 would still
          # lease one message).
          if len(leased_messages) >= limit:
            return leased_messages

          if existing_lease:
            lease_count = existing_lease[-1] + 1
            # >= comparison since this check happens before the lease.
            if lease_count >= db.Database.CLIENT_MESSAGES_TTL:
              self._DeleteClientMessage(client_id, msg.task_id)
              continue
            leases[msg.task_id] = (expiration_time, process_id_str, lease_count)
          else:
            leases[msg.task_id] = (expiration_time, process_id_str, 0)

          msg.leased_until = expiration_time
          msg.leased_by = process_id_str
          leased_messages.append(msg)

    return leased_messages
Beispiel #9
0
  def LeaseClientMessages(self,
                          client_id,
                          lease_time=None,
                          limit=None,
                          cursor=None):
    """Leases available client messages for the client with the given id.

    Each lease increments the row's leased_count. Messages whose count now
    exceeds db.Database.CLIENT_MESSAGES_TTL are deleted instead of returned.

    Args:
      client_id: Id of the client whose messages should be leased.
      lease_time: Duration added to the current time to form the lease expiry.
      limit: Maximum number of messages to lease; None means no limit.
      cursor: Database cursor used for all statements.

    Returns:
      The non-expired leased GrrMessages sorted by task_id, with leased_by,
      leased_until and ttl set.
    """

    now = rdfvalue.RDFDatetime.Now()
    now_str = mysql_utils.RDFDatetimeToMysqlString(now)
    expiry = now + lease_time
    expiry_str = mysql_utils.RDFDatetimeToMysqlString(expiry)
    proc_id_str = utils.ProcessIdString()
    client_id_int = mysql_utils.ClientIDToInt(client_id)

    query = ("UPDATE client_messages "
             "SET leased_until=%s, leased_by=%s, leased_count=leased_count+1 "
             "WHERE client_id=%s AND "
             "(leased_until IS NULL OR leased_until < %s)")
    args = [expiry_str, proc_id_str, client_id_int, now_str]
    if limit is not None:
      # Binding None would render "LIMIT NULL", which MySQL rejects, so the
      # clause is only added for an actual limit.
      query += " LIMIT %s"
      args.append(limit)

    num_leased = cursor.execute(query, args)
    if num_leased == 0:
      return []

    query = ("SELECT message, leased_count FROM client_messages "
             "WHERE client_id=%s AND leased_until=%s AND leased_by=%s")

    cursor.execute(query, [client_id_int, expiry_str, proc_id_str])

    ret = []
    expired = []
    for msg, leased_count in cursor.fetchall():
      message = rdf_flows.GrrMessage.FromSerializedString(msg)
      message.leased_by = proc_id_str
      message.leased_until = expiry
      message.ttl = db.Database.CLIENT_MESSAGES_TTL - leased_count
      # > comparison since this check happens after the lease.
      if leased_count > db.Database.CLIENT_MESSAGES_TTL:
        expired.append((client_id, message.task_id))
      else:
        ret.append(message)

    if expired:
      self._DeleteClientMessages(expired, cursor=cursor)

    return sorted(ret, key=lambda msg: msg.task_id)
Beispiel #10
0
    def ReadFlowForProcessing(self,
                              client_id,
                              flow_id,
                              processing_time,
                              cursor=None):
        """Marks a flow as being processed on this worker and returns it.

        Args:
          client_id: Id of the client the flow belongs to.
          flow_id: Id of the flow to mark.
          processing_time: Duration added to the current time to form the
            processing deadline.
          cursor: Database cursor used for the SELECT and the UPDATE.

        Returns:
          The flow object with processing_on, processing_since and
          processing_deadline set to the values written by the UPDATE.

        Raises:
          db.UnknownFlowError: if no such flow exists.
          ValueError: if the flow is already marked as processed and its
            processing deadline has not yet passed.
        """
        query = ("SELECT " + self.FLOW_DB_FIELDS +
                 "FROM flows WHERE client_id=%s AND flow_id=%s")
        cursor.execute(query, [
            mysql_utils.ClientIDToInt(client_id),
            mysql_utils.FlowIDToInt(flow_id)
        ])
        response = cursor.fetchall()
        if not response:
            raise db.UnknownFlowError(client_id, flow_id)

        row, = response
        rdf_flow = self._FlowObjectFromRow(row)

        now = rdfvalue.RDFDatetime.Now()
        if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
            # The message names the flow first, then the client.
            raise ValueError(
                "Flow %s on client %s is already being processed." %
                (flow_id, client_id))
        update_query = (
            "UPDATE flows SET processing_on=%s, processing_since=%s, "
            "processing_deadline=%s WHERE client_id=%s and flow_id=%s")
        processing_deadline = now + processing_time
        process_id_string = utils.ProcessIdString()

        args = [
            process_id_string,
            mysql_utils.RDFDatetimeToMysqlString(now),
            mysql_utils.RDFDatetimeToMysqlString(processing_deadline),
            mysql_utils.ClientIDToInt(client_id),
            mysql_utils.FlowIDToInt(flow_id)
        ]
        cursor.execute(update_query, args)

        # This needs to happen after we are sure that the write has succeeded.
        rdf_flow.processing_on = process_id_string
        rdf_flow.processing_since = now
        rdf_flow.processing_deadline = processing_deadline
        return rdf_flow
Beispiel #11
0
 def ReadFlowForProcessing(self, client_id, flow_id, processing_time):
   """Marks a flow as being processed on this worker and returns it.

   Args:
     client_id: Id of the client the flow belongs to.
     flow_id: Id of the flow to mark.
     processing_time: Duration added to the current time to form the
       processing deadline.

   Returns:
     The flow object with processing_on, processing_since and
     processing_deadline set to the values written via UpdateFlow.

   Raises:
     ValueError: if the flow is already marked as processed and its
       processing deadline has not yet passed.
   """
   rdf_flow = self.ReadFlowObject(client_id, flow_id)
   now = rdfvalue.RDFDatetime.Now()
   if rdf_flow.processing_on and rdf_flow.processing_deadline > now:
     # The message names the flow first, then the client.
     raise ValueError("Flow %s on client %s is already being processed." %
                      (flow_id, client_id))
   processing_deadline = now + processing_time
   process_id_string = utils.ProcessIdString()
   self.UpdateFlow(
       client_id,
       flow_id,
       processing_on=process_id_string,
       processing_since=now,
       processing_deadline=processing_deadline)
   # Mirror the stored values on the returned object.
   rdf_flow.processing_on = process_id_string
   rdf_flow.processing_since = now
   rdf_flow.processing_deadline = processing_deadline
   return rdf_flow
Beispiel #12
0
  def LeaseCronJobs(self, cronjob_ids=None, lease_time=None):
    """Leases every cron job whose lease is missing or already expired.

    Args:
      cronjob_ids: Optional collection of job ids; when non-empty only these
        jobs are considered.
      lease_time: Duration added to the current time to form the expiry.

    Returns:
      Copies of the leased jobs with leased_until and leased_by populated.
      The stored job objects themselves are not modified; the lease is kept
      in self.cronjob_leases.
    """
    now = rdfvalue.RDFDatetime.Now()
    lease_expiry = now + lease_time
    result = []

    for job in self.cronjobs.values():
      job_id = job.cron_job_id
      if cronjob_ids and job_id not in cronjob_ids:
        continue

      current_lease = self.cronjob_leases.get(job_id)
      if current_lease is not None and not current_lease[0] < now:
        continue  # Someone still holds a valid lease.

      new_lease = (lease_expiry, utils.ProcessIdString())
      self.cronjob_leases[job_id] = new_lease
      leased_copy = job.Copy()
      leased_copy.leased_until, leased_copy.leased_by = new_lease
      result.append(leased_copy)

    return result
Beispiel #13
0
  def _LeaseMessageHandlerRequests(self, lease_time, limit):
    """Read and lease some outstanding message handler requests.

    A request is leasable when it has no recorded lease in
    `self.message_handler_leases` or its lease expired before now. Leased
    requests are updated in place with `leased_until` and `leased_by`.

    Args:
      lease_time: Duration added to the current time to get the expiration.
      limit: Maximum number of requests to lease in this call.

    Returns:
      A list with at most `limit` leased requests.
    """
    leased_requests = []

    now = rdfvalue.RDFDatetime.Now()
    # Requests without a recorded lease are treated as leased until the epoch.
    zero = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0)
    expiration_time = now + lease_time
    process_id_str = utils.ProcessIdString()

    leases = self.message_handler_leases
    for requests in itervalues(self.message_handler_requests):
      for r in itervalues(requests):
        existing_lease = leases.get(r.handler_name, {}).get(r.request_id, zero)
        if existing_lease < now:
          # Return rather than `break`: a break only leaves the inner loop,
          # letting later handlers exceed the limit (and limit <= 0 would
          # still lease one request).
          if len(leased_requests) >= limit:
            return leased_requests
          leases.setdefault(r.handler_name, {})[r.request_id] = expiration_time
          r.leased_until = expiration_time
          r.leased_by = process_id_str
          leased_requests.append(r)

    return leased_requests
Beispiel #14
0
    def testMessageHandlerRequestLeasing(self):
        """A registered handler eventually receives all written requests.

        Requests arrive in leased batches of at most 5, each stamped with this
        process as owner, a future lease expiry and a past timestamp.
        """
        requests = [
            rdf_objects.MessageHandlerRequest(client_id="C.1000000000000000",
                                              handler_name="Testhandler",
                                              request_id=idx * 100,
                                              request=rdfvalue.RDFInteger(idx))
            for idx in range(10)
        ]
        lease_time = rdfvalue.Duration("5m")

        write_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000)
        with test_lib.FakeTime(write_time):
            self.db.WriteMessageHandlerRequests(requests)

        batches = queue.Queue()
        self.db.RegisterMessageHandler(batches.put, lease_time, limit=5)

        received = []
        while len(received) < 10:
            try:
                batch = batches.get(True, timeout=6)
            except queue.Empty:
                self.fail(
                    "Timed out waiting for messages, expected 10, got %d" %
                    len(received))
            self.assertLessEqual(len(batch), 5)
            for req in batch:
                self.assertEqual(req.leased_by, utils.ProcessIdString())
                self.assertGreater(req.leased_until,
                                   rdfvalue.RDFDatetime.Now())
                self.assertLess(req.timestamp, rdfvalue.RDFDatetime.Now())
                # Strip lease metadata so the batch compares equal to input.
                req.leased_by = None
                req.leased_until = None
                req.timestamp = None
            received.extend(batch)
        self.db.DeleteMessageHandlerRequests(received)

        received.sort(key=lambda req: req.request_id)
        self.assertEqual(requests, received)
Beispiel #15
0
  def LeaseClientActionRequests(self,
                                client_id,
                                lease_time=None,
                                limit=sys.maxsize):
    """Leases available client action requests for a client.

    Requests are visited in sorted key order. A request is leasable when it
    has no recorded lease or its lease expired before now. A request whose
    renewed lease count would exceed db.Database.CLIENT_MESSAGES_TTL is
    deleted instead of leased.

    Args:
      client_id: Id of the client; request keys start with the client id.
      lease_time: Duration added to the current time to form the expiry.
      limit: Maximum number of requests to lease.

    Returns:
      The leased requests, each with leased_until, leased_by and ttl set.
    """
    result = []

    now = rdfvalue.RDFDatetime.Now()
    lease_expiry = now + lease_time
    lease_owner = utils.ProcessIdString()

    lease_table = self.client_action_request_leases
    # sorted() materializes a list first, so deleting requests below does not
    # disturb the iteration.
    for key, request in sorted(self.client_action_requests.items()):
      if key[0] != client_id:
        continue

      previous_lease = lease_table.get(key)
      if previous_lease and not previous_lease[0] < now:
        continue  # Still validly leased.

      lease_count = previous_lease[-1] + 1 if previous_lease else 1
      if previous_lease and lease_count > db.Database.CLIENT_MESSAGES_TTL:
        self._DeleteClientActionRequest(*key)
        continue

      lease_table[key] = (lease_expiry, lease_owner, lease_count)
      request.leased_until = lease_expiry
      request.leased_by = lease_owner
      request.ttl = db.Database.CLIENT_MESSAGES_TTL - lease_count
      result.append(request)
      if len(result) >= limit:
        break

    return result
Beispiel #16
0
  def LeaseClientMessages(self,
                          client_id,
                          lease_time=None,
                          limit=None,
                          cursor=None):
    """Leases available client messages for the client with the given id.

    Args:
      client_id: Id of the client whose messages should be leased.
      lease_time: Duration added to the current time to form the lease expiry.
      limit: Maximum number of messages to lease; None means no limit.
      cursor: Database cursor used for both statements.

    Returns:
      The leased GrrMessages sorted by task_id, with leased_by and
      leased_until set.
    """

    now = rdfvalue.RDFDatetime.Now()
    now_str = mysql_utils.RDFDatetimeToMysqlString(now)
    expiry = now + lease_time
    expiry_str = mysql_utils.RDFDatetimeToMysqlString(expiry)
    proc_id_str = utils.ProcessIdString()
    client_id_int = mysql_utils.ClientIDToInt(client_id)

    query = ("UPDATE client_messages "
             "SET leased_until=%s, leased_by=%s "
             "WHERE client_id=%s AND "
             "(leased_until IS NULL OR leased_until < %s)")
    args = [expiry_str, proc_id_str, client_id_int, now_str]
    if limit is not None:
      # Binding None would render "LIMIT NULL", which MySQL rejects, so the
      # clause is only added for an actual limit.
      query += " LIMIT %s"
      args.append(limit)

    num_leased = cursor.execute(query, args)
    if num_leased == 0:
      return []

    query = ("SELECT message FROM client_messages "
             "WHERE client_id=%s AND leased_until=%s AND leased_by=%s")

    cursor.execute(query, [client_id_int, expiry_str, proc_id_str])

    ret = []
    for msg, in cursor.fetchall():
      message = rdf_flows.GrrMessage.FromSerializedString(msg)
      message.leased_by = proc_id_str
      message.leased_until = expiry
      ret.append(message)
    return sorted(ret, key=lambda msg: msg.task_id)
Beispiel #17
0
    def testClientMessageLeasing(self):
        """Client messages are leased in batches and become leasable on expiry.

        Ten messages are leased in two batches of five at different fake
        times. Once the earlier leases expire only that half is leasable
        again, and once all leases expire all ten are.
        """
        client_id = self.InitializeClient()
        messages = [
            rdf_flows.GrrMessage(queue=client_id, generate_task_id=True)
            for _ in range(10)
        ]
        lease_time = rdfvalue.Duration("5m")

        self.db.WriteClientMessages(messages)

        def _Lease(**kwargs):
            return self.db.LeaseClientMessages(client_id,
                                               lease_time=lease_time,
                                               **kwargs)

        def _AssertBatchOfFive(batch, expected_expiry):
            self.assertEqual(len(batch), 5)
            for leased_msg in batch:
                self.assertEqual(leased_msg.leased_until, expected_expiry)
                self.assertEqual(leased_msg.leased_by,
                                 utils.ProcessIdString())

        base_secs = 100000

        t0 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs)
        with test_lib.FakeTime(t0):
            t0_expiry = t0 + lease_time
            _AssertBatchOfFive(_Lease(limit=5), t0_expiry)

        t1 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 100)
        with test_lib.FakeTime(t1):
            t1_expiry = t1 + lease_time
            _AssertBatchOfFive(_Lease(limit=5), t1_expiry)

            # Nothing left to lease.
            self.assertEqual(len(_Lease(limit=2)), 0)

        stored = self.db.ReadClientMessages(client_id)

        self.assertEqual(len(stored), 10)
        for stored_msg in stored:
            self.assertEqual(stored_msg.leased_by, utils.ProcessIdString())

        for expiry in (t0_expiry, t1_expiry):
            matching = [m for m in stored if m.leased_until == expiry]
            self.assertEqual(len(matching), 5)

        # Half the leases expired.
        t2 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 350)
        with test_lib.FakeTime(t2):
            self.assertEqual(len(_Lease()), 5)

        # All of them expired.
        t3 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 10350)
        with test_lib.FakeTime(t3):
            self.assertEqual(len(_Lease()), 10)