예제 #1
0
    def LeaseCronJobs(self, cronjob_ids=None, lease_time=None, cursor=None):
        """Leases every cron job that is unleased or whose lease has expired.

        Args:
          cronjob_ids: Optional sequence of job ids restricting which jobs are
            candidates. When falsy (None or empty), all cron jobs are candidates.
          lease_time: Duration added to the current time to form the lease
            expiry.
          cursor: Database cursor used to run the UPDATE and SELECT.

        Returns:
          A list of cron jobs built with self._CronjobFromRow for the rows that
          now hold this process's lease, or [] if the UPDATE changed no rows.
        """
        current_time = rdfvalue.RDFDatetime.Now()
        lease_expiry = mysql_utils.RDFDatetimeToMysqlString(
            current_time + lease_time)
        owner = utils.ProcessIdString()

        conditions = "WHERE (leased_until IS NULL OR leased_until < %s)"
        params = [
            lease_expiry,
            owner,
            mysql_utils.RDFDatetimeToMysqlString(current_time),
        ]
        if cronjob_ids:
            # One placeholder per requested id so the driver escapes each value.
            placeholders = ", ".join(["%s"] * len(cronjob_ids))
            conditions += " AND job_id in (%s)" % placeholders
            params.extend(cronjob_ids)

        rows_changed = cursor.execute(
            "UPDATE cron_jobs "
            "SET leased_until=%s, leased_by=%s " + conditions, params)
        if rows_changed == 0:
            return []

        # Read back the rows stamped with exactly this expiry and owner.
        cursor.execute(
            "SELECT job, create_time, disabled, "
            "last_run_status, last_run_time, current_run_id, state, "
            "leased_until, leased_by "
            "FROM cron_jobs WHERE leased_until=%s AND leased_by=%s",
            [lease_expiry, owner])
        return list(map(self._CronjobFromRow, cursor.fetchall()))
예제 #2
0
파일: mem_flows.py 프로젝트: qsdj/grr
    def LeaseClientMessages(self,
                            client_id,
                            lease_time=None,
                            limit=sys.maxsize):
        """Leases available client messages for the client with the given id.

        A message can be leased when self.client_message_leases has no entry
        for its task_id, or that entry's expiry is earlier than now.

        Args:
          client_id: Id of the client whose messages should be leased. It is
            compared against the first component of each message's queue.
          lease_time: Duration added to the current time to form the lease
            expiry.
          limit: Maximum number of messages leased by this call.

        Returns:
          A list of at most `limit` leased messages, each with leased_until and
          leased_by set in place.
        """
        leased_messages = []

        now = rdfvalue.RDFDatetime.Now()
        expiration_time = now + lease_time
        process_id_str = utils.ProcessIdString()

        leases = self.client_message_leases
        for msgs_by_id in self.client_messages.values():
            for msg in msgs_by_id.values():
                if msg.queue.Split()[0] != client_id:
                    continue

                existing_lease = leases.get(msg.task_id)
                if not existing_lease or existing_lease[0] < now:
                    leases[msg.task_id] = (expiration_time, process_id_str)
                    msg.leased_until = expiration_time
                    msg.leased_by = process_id_str
                    leased_messages.append(msg)
                    # Return instead of `break`: a `break` only leaves the
                    # inner loop, so later groups would keep leasing past the
                    # limit.
                    if len(leased_messages) >= limit:
                        return leased_messages

        return leased_messages
예제 #3
0
    def _LeaseMessageHandlerRequests(self, lease_time, limit, cursor=None):
        """Leases up to `limit` message handler requests that are free to lease.

        A row is free when leased_until is NULL or earlier than now.

        Args:
          lease_time: Duration added to the current time to form the lease
            expiry.
          limit: Maximum number of rows the UPDATE may lease.
          cursor: Database cursor used to run the UPDATE and SELECT.

        Returns:
          A list of MessageHandlerRequest objects. Each has its timestamp read
          from the row and leased_until/leased_by set to this lease. Returns []
          if the UPDATE changed no rows.
        """
        start = rdfvalue.RDFDatetime.Now()
        lease_expiry = start + lease_time
        expiry_param = mysql_utils.RDFDatetimeToMysqlString(lease_expiry)
        owner = utils.ProcessIdString()

        num_leased = cursor.execute(
            "UPDATE message_handler_requests "
            "SET leased_until=%s, leased_by=%s "
            "WHERE leased_until IS NULL OR leased_until < %s "
            "LIMIT %s",
            (expiry_param, owner, mysql_utils.RDFDatetimeToMysqlString(start),
             limit))
        if num_leased == 0:
            return []

        # Fetch back no more rows than the UPDATE reported as leased.
        cursor.execute(
            "SELECT timestamp, request FROM message_handler_requests "
            "WHERE leased_by=%s AND leased_until=%s LIMIT %s",
            (owner, expiry_param, num_leased))

        leased = []
        for raw_timestamp, serialized in cursor.fetchall():
            request = rdf_objects.MessageHandlerRequest.FromSerializedString(
                serialized)
            request.timestamp = mysql_utils.MysqlToRDFDatetime(raw_timestamp)
            request.leased_until = lease_expiry
            request.leased_by = owner
            leased.append(request)
        return leased
예제 #4
0
  def LeaseCronJobs(self, cronjob_ids=None, lease_time=None):
    """Leases every cron job that is unleased or whose lease has expired.

    Args:
      cronjob_ids: Optional collection of job ids to restrict leasing to. When
        falsy, every stored cron job is a candidate.
      lease_time: Duration added to the current time to form the lease expiry.

    Returns:
      A list of copies of the leased jobs with leased_until/leased_by filled
      in. The stored job objects are not modified; only self.cronjob_leases is
      updated.
    """
    now = rdfvalue.RDFDatetime.Now()
    expiration_time = now + lease_time

    result = []
    for candidate in self.cronjobs.values():
      if cronjob_ids and candidate.job_id not in cronjob_ids:
        continue
      current = self.cronjob_leases.get(candidate.job_id)
      if current is not None and not current[0] < now:
        # Still held by an unexpired lease.
        continue
      new_lease = (expiration_time, utils.ProcessIdString())
      self.cronjob_leases[candidate.job_id] = new_lease
      leased_copy = candidate.Copy()
      leased_copy.leased_until, leased_copy.leased_by = new_lease
      result.append(leased_copy)
    return result
예제 #5
0
  def _LeaseMessageHandlerRequests(self, lease_time, limit):
    """Read and lease some outstanding message handler requests.

    A request can be leased when self.message_handler_leases has no expiry
    for it (treated as the epoch) or the recorded expiry is earlier than now.

    Args:
      lease_time: Duration added to the current time to form the lease expiry.
      limit: Maximum number of requests leased by this call.

    Returns:
      A list of at most `limit` leased request objects. These are the stored
      objects themselves, updated in place with leased_until/leased_by.
    """
    leased_requests = []

    now = rdfvalue.RDFDatetime.Now()
    zero = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0)
    expiration_time = now + lease_time
    process_id_str = utils.ProcessIdString()

    leases = self.message_handler_leases
    for requests in self.message_handler_requests.values():
      for r in requests.values():
        existing_lease = leases.get(r.handler_name, {}).get(r.request_id, zero)
        if existing_lease < now:
          leases.setdefault(r.handler_name, {})[r.request_id] = expiration_time
          r.leased_until = expiration_time
          r.leased_by = process_id_str
          leased_requests.append(r)
          # Return instead of `break`: a `break` only leaves the inner loop,
          # so later handler groups would keep leasing past the limit.
          if len(leased_requests) >= limit:
            return leased_requests

    return leased_requests
예제 #6
0
    def LeaseClientMessages(self,
                            client_id,
                            lease_time=None,
                            limit=None,
                            cursor=None):
        """Leases available client messages for the client with the given id.

        A message can be leased when its leased_until is NULL or earlier than
        now.

        Args:
          client_id: Client id, converted with mysql_utils.ClientIDToInt for
            the query.
          lease_time: Duration added to the current time to form the lease
            expiry.
          limit: Maximum number of messages to lease. None means no limit: the
            LIMIT clause is omitted, because MySQL does not accept LIMIT NULL.
          cursor: Database cursor used to run the UPDATE and SELECT.

        Returns:
          A list of GrrMessage objects with leased_by/leased_until set to this
          lease, or [] if the UPDATE changed no rows.
        """
        now = rdfvalue.RDFDatetime.Now()
        now_str = mysql_utils.RDFDatetimeToMysqlString(now)
        expiry = now + lease_time
        expiry_str = mysql_utils.RDFDatetimeToMysqlString(expiry)
        proc_id_str = utils.ProcessIdString()
        client_id_int = mysql_utils.ClientIDToInt(client_id)

        query = ("UPDATE client_messages "
                 "SET leased_until=%s, leased_by=%s "
                 "WHERE client_id=%s AND "
                 "(leased_until IS NULL OR leased_until < %s)")
        args = [expiry_str, proc_id_str, client_id_int, now_str]
        if limit is not None:
            query += " LIMIT %s"
            args.append(limit)

        num_leased = cursor.execute(query, args)
        if num_leased == 0:
            return []

        # Read back the rows stamped with exactly this expiry and owner.
        query = ("SELECT message FROM client_messages "
                 "WHERE client_id=%s AND leased_until=%s AND leased_by=%s")

        cursor.execute(query, [client_id_int, expiry_str, proc_id_str])

        ret = []
        for msg, in cursor.fetchall():
            message = rdf_flows.GrrMessage.FromSerializedString(msg)
            message.leased_by = proc_id_str
            message.leased_until = expiry
            ret.append(message)
        return ret
예제 #7
0
  def testMessageHandlerRequestLeasing(self):
    """A registered handler receives all requests in batches of at most 5."""
    expected_requests = []
    for i in range(10):
      expected_requests.append(
          rdf_objects.MessageHandlerRequest(
              client_id="C.1000000000000000",
              handler_name="Testhandler",
              request_id=i * 100,
              request=rdfvalue.RDFInteger(i)))
    lease_time = rdfvalue.Duration("5m")

    write_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000)
    with test_lib.FakeTime(write_time):
      self.db.WriteMessageHandlerRequests(expected_requests)

    batches = Queue.Queue()
    self.db.RegisterMessageHandler(batches.put, lease_time, limit=5)

    received = []
    while len(received) < 10:
      try:
        batch = batches.get(True, timeout=6)
      except Queue.Empty:
        self.fail("Timed out waiting for messages, expected 10, got %d" %
                  len(received))
      self.assertLessEqual(len(batch), 5)
      for request in batch:
        self.assertEqual(request.leased_by, utils.ProcessIdString())
        self.assertGreater(request.leased_until, rdfvalue.RDFDatetime.Now())
        self.assertLess(request.timestamp, rdfvalue.RDFDatetime.Now())
        # Clear the lease/timestamp fields so the final comparison only
        # checks the originally written contents.
        request.leased_by = None
        request.leased_until = None
        request.timestamp = None
      received.extend(batch)
    self.db.DeleteMessageHandlerRequests(received)

    received.sort(key=lambda req: req.request_id)
    self.assertEqual(expected_requests, received)
예제 #8
0
파일: db_flows_test.py 프로젝트: qsdj/grr
    def testClientMessageLeasing(self):
        """Checks batch limits, lease ownership and expiry for client messages."""

        client_id = self.InitializeClient()
        messages = []
        for _ in range(10):
            messages.append(
                rdf_flows.GrrMessage(queue=client_id, generate_task_id=True))
        lease_time = rdfvalue.Duration("5m")

        self.db.WriteClientMessages(messages)

        def _Lease(**limit_kwargs):
            # Leases this client's messages with the shared lease duration.
            return self.db.LeaseClientMessages(client_id,
                                               lease_time=lease_time,
                                               **limit_kwargs)

        def _AssertLeasedBatch(batch, expected_expiry):
            # A full batch of 5, all leased by this process until the expiry.
            self.assertEqual(len(batch), 5)
            for message in batch:
                self.assertEqual(message.leased_until, expected_expiry)
                self.assertEqual(message.leased_by, utils.ProcessIdString())

        base_secs = 100000

        t0 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs)
        with test_lib.FakeTime(t0):
            t0_expiry = t0 + lease_time
            _AssertLeasedBatch(_Lease(limit=5), t0_expiry)

        t1 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 100)
        with test_lib.FakeTime(t1):
            t1_expiry = t1 + lease_time
            _AssertLeasedBatch(_Lease(limit=5), t1_expiry)

            # Every message is now held by an unexpired lease.
            self.assertEqual(len(_Lease(limit=2)), 0)

        stored = self.db.ReadClientMessages(client_id)

        self.assertEqual(len(stored), 10)
        for message in stored:
            self.assertEqual(message.leased_by, utils.ProcessIdString())

        self.assertEqual(
            sum(1 for m in stored if m.leased_until == t0_expiry), 5)
        self.assertEqual(
            sum(1 for m in stored if m.leased_until == t1_expiry), 5)

        # Only the first batch's leases have expired by now.
        t2 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 350)
        with test_lib.FakeTime(t2):
            self.assertEqual(len(_Lease()), 5)

        # Every lease has expired by now.
        t3 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 10350)
        with test_lib.FakeTime(t3):
            self.assertEqual(len(_Lease()), 10)
예제 #9
0
    def testMessageHandlerRequestLeasing(self):
        """Checks batch limits, lease ownership and expiry for handler requests."""

        handler_requests = []
        for i in range(10):
            handler_requests.append(
                rdf_objects.MessageHandlerRequest(
                    client_id="C.1000000000000000",
                    handler_name="Testhandler",
                    request_id=i * 100,
                    request=rdfvalue.RDFInteger(i)))
        lease_time = rdfvalue.Duration("5m")

        with test_lib.FakeTime(
                rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000)):
            self.db.WriteMessageHandlerRequests(handler_requests)

        def _Lease(**limit_kwargs):
            # Leases handler requests with the shared lease duration.
            return self.db.LeaseMessageHandlerRequests(lease_time=lease_time,
                                                       **limit_kwargs)

        def _AssertLeasedBatch(batch, expected_expiry):
            # A full batch of 5, all leased by this process until the expiry.
            self.assertEqual(len(batch), 5)
            for request in batch:
                self.assertEqual(request.leased_until, expected_expiry)
                self.assertEqual(request.leased_by, utils.ProcessIdString())

        base_secs = 100000

        t0 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs)
        with test_lib.FakeTime(t0):
            t0_expiry = t0 + lease_time
            _AssertLeasedBatch(_Lease(limit=5), t0_expiry)

        t1 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 100)
        with test_lib.FakeTime(t1):
            t1_expiry = t1 + lease_time
            _AssertLeasedBatch(_Lease(limit=5), t1_expiry)

            # Every request is now held by an unexpired lease.
            self.assertEqual(len(_Lease(limit=2)), 0)

        stored = self.db.ReadMessageHandlerRequests()

        self.assertEqual(len(stored), 10)
        for request in stored:
            self.assertEqual(request.leased_by, utils.ProcessIdString())

        self.assertEqual(
            sum(1 for r in stored if r.leased_until == t0_expiry), 5)
        self.assertEqual(
            sum(1 for r in stored if r.leased_until == t1_expiry), 5)

        # Only the first batch's leases have expired by now.
        t2 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 350)
        with test_lib.FakeTime(t2):
            self.assertEqual(len(_Lease()), 5)

        # Every lease has expired by now.
        t3 = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(base_secs + 10350)
        with test_lib.FakeTime(t3):
            self.assertEqual(len(_Lease()), 10)