Пример #1
0
 def __decode_sth(self, sth_row):
     """Rebuild a client_pb2.AuditedSth from a stored STH row.

     Args:
         sth_row: a 5-element sequence (presumably a database row). The
             first two elements are ignored. The remaining three are the
             timestamp, the serialized STH and the serialized audit info.

     Returns:
         A new AuditedSth. Its sth.timestamp comes from the row's timestamp
         element, not from the serialized STH.
     """
     (_, _, row_timestamp, sth_blob, audit_blob) = sth_row
     decoded = client_pb2.AuditedSth()
     decoded.sth.ParseFromString(sth_blob)
     # Must follow the parse: ParseFromString would otherwise reset it.
     decoded.sth.timestamp = row_timestamp
     decoded.audit.ParseFromString(audit_blob)
     return decoded
Пример #2
0
 def test_store_sth_ignores_duplicate(self):
     """Storing an exact copy of a stored STH must not replace it.

     The second store reuses the default STH unchanged and only flips the
     audit status. It is therefore a genuine duplicate (same timestamp and
     root hash), and the differing status reveals whether it overwrote the
     first copy.
     """
     self.db().add_log(LogDBTest.default_log)
     self.db().store_sth(LogDBTest.default_log.log_server,
                         LogDBTest.default_sth)
     duplicate_sth = client_pb2.AuditedSth()
     # Without this copy the "duplicate" would be an empty STH with
     # timestamp 0. That is merely older, and the test would pass even if
     # duplicates were not ignored.
     duplicate_sth.CopyFrom(LogDBTest.default_sth)
     duplicate_sth.audit.status = client_pb2.VERIFY_ERROR
     self.db().store_sth(LogDBTest.default_log.log_server, duplicate_sth)
     read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
     self.assertIsNotNone(read_sth)
     self.assertEqual(LogDBTest.default_sth, read_sth)
Пример #3
0
 def test_get_latest_sth_returns_latest(self):
     """get_latest_sth picks the highest timestamp, not the last stored."""
     log_server = LogDBTest.default_log.log_server
     newest = LogDBTest.default_sth
     self.db().add_log(LogDBTest.default_log)
     self.db().store_sth(log_server, newest)

     # An otherwise identical STH, one tick older, stored *after* the
     # newer one.
     older = client_pb2.AuditedSth()
     older.CopyFrom(newest)
     older.sth.timestamp = newest.sth.timestamp - 1
     self.db().store_sth(log_server, older)

     latest = self.db().get_latest_sth(log_server)
     self.assertIsNotNone(latest)
     self.assertEqual(newest, latest)
Пример #4
0
    def test_scan_latest_sth_range_honours_limit(self):
        """A scan with limit=1 yields only the most recent STH."""
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        for i in range(4):
            sth = client_pb2.AuditedSth()
            sth.sth.timestamp = i
            sth.sth.sha256_root_hash = "hash-%d" % i
            self.db().store_sth(log_server, sth)

        # Scan the log we actually populated instead of a hard-coded name.
        generator = self.db().scan_latest_sth_range(log_server, limit=1)
        # The next() builtin works on Python 2.6+ and 3; generator.next()
        # does not exist on Python 3.
        sth = next(generator)
        # Returns most recent
        self.assertEqual(sth.sth.timestamp, 3)
        self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % 3)

        self.assertRaises(StopIteration, next, generator)
Пример #5
0
    def test_scan_latest_sth_range_honours_range(self):
        """A scan with start=1, end=2 yields exactly timestamps 2 and 1.

        Both bounds are inclusive and results arrive newest first.
        """
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        for i in range(4):
            sth = client_pb2.AuditedSth()
            sth.sth.timestamp = i
            sth.sth.sha256_root_hash = "hash-%d" % i
            self.db().store_sth(log_server, sth)

        # Scan the log we actually populated instead of a hard-coded name.
        generator = self.db().scan_latest_sth_range(log_server, start=1,
                                                    end=2)
        for expected in (2, 1):
            # next() builtin: portable across Python 2.6+ and 3.
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, expected)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % expected)

        self.assertRaises(StopIteration, next, generator)
Пример #6
0
    def test_scan_latest_sth_range_honours_log_server(self):
        """A scan only returns STHs of the requested log.

        Four logs each receive one STH (log "test-i" gets timestamp i). Each
        per-log scan must yield that single STH and then be exhausted.
        """
        for i in range(4):
            log = client_pb2.CtLogMetadata()
            log.log_server = "test-%d" % i
            self.db().add_log(log)
        for i in range(4):
            sth = client_pb2.AuditedSth()
            sth.sth.timestamp = i
            sth.sth.sha256_root_hash = "hash-%d" % i
            self.db().store_sth("test-%d" % i, sth)

        for i in range(4):
            generator = self.db().scan_latest_sth_range("test-%d" % i)
            # next() builtin: portable across Python 2.6+ and 3.
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, i)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % i)
            # Nothing from the other logs may follow.
            self.assertRaises(StopIteration, next, generator)
Пример #7
0
    def test_scan_latest_sth_range_finds_all(self):
        """An unbounded scan yields every stored STH, newest first."""
        self.db().add_log(LogDBTest.default_log)
        for i in range(4):
            sth = client_pb2.AuditedSth()
            sth.sth.timestamp = i
            sth.sth.sha256_root_hash = "hash-%d" % i
            self.db().store_sth(LogDBTest.default_log.log_server, sth)

        generator = self.db().scan_latest_sth_range(
            LogDBTest.default_log.log_server)
        for i in range(3, -1, -1):
            # next() builtin: portable across Python 2.6+ and 3.
            sth = next(generator)
            # Scan runs in descending timestamp order
            self.assertEqual(sth.sth.timestamp, i)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % i)

        self.assertRaises(StopIteration, next, generator)
Пример #8
0
    def test_get_latest_sth_honours_log_server(self):
        """get_latest_sth only considers STHs of the requested log.

        A second log ("test2") receives an STH that is newer than the
        default one. The default log's latest STH must still be the default
        STH.
        """
        default_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(default_server, LogDBTest.default_sth)

        other_log = client_pb2.CtLogMetadata()
        other_log.log_server = "test2"
        self.db().add_log(other_log)

        # Newer timestamp and a different root hash, but for the other log.
        other_sth = client_pb2.AuditedSth()
        other_sth.CopyFrom(LogDBTest.default_sth)
        other_sth.sth.timestamp = LogDBTest.default_sth.sth.timestamp + 1
        other_sth.sth.sha256_root_hash = "hash2"
        self.db().store_sth(other_log.log_server, other_sth)

        latest = self.db().get_latest_sth(default_server)
        self.assertIsNotNone(latest)
        self.assertEqual(LogDBTest.default_sth, latest)
Пример #9
0
 def __get_audited_sth(self, sth, verify_status):
     """Wrap *sth* in a fresh AuditedSth whose audit status is *verify_status*.

     *sth* is copied with CopyFrom, so later changes to it do not affect the
     returned message.
     """
     wrapped = client_pb2.AuditedSth()
     wrapped.sth.CopyFrom(sth)
     wrapped.audit.status = verify_status
     return wrapped
Пример #10
0
    def _update_sth(self):
        """Get a new candidate STH. If update succeeds, stores the new STH as
        pending. Does nothing if there is already a pending
        STH.

        The candidate is rejected (False is returned) in four cases:
        get-sth fails; it is older than the current verified STH (if any);
        verify_sth() raises EncodingError/VerifyError; or
        _verify_consistency() raises VerifyError. A candidate equal to the
        current verified STH is accepted (True) without being stored again.

        Returns: True if the update succeeded."""
        if self.__state.HasField("pending_sth"):
            return True
        logging.info("Fetching new STH")
        try:
            sth_response = self.__client.get_sth()
        except (log_client.HTTPError, log_client.InvalidResponseError) as e:
            logging.error("get-sth from %s failed: %s" % (self.servername, e))
            return False
        logging.info("Got new STH: %s" % sth_response)

        # If we got the same response as last time, do nothing.
        # If we got an older response than last time, return False.
        # (It is not necessarily an inconsistency - the log could be out of
        # sync - but we should not rewind to older data.)
        #
        # The client should always return an STH but best eliminate the
        # None == None case explicitly by only shortcutting the verification
        # if we already have a verified STH.
        if self.__state.HasField("verified_sth"):
            verified_sth = self.__state.verified_sth
            if sth_response == verified_sth:
                logging.info("Ignoring already-verified STH: %s" %
                             sth_response)
                return True
            if sth_response.timestamp < verified_sth.timestamp:
                logging.error("Rejecting received STH: timestamp is older "
                              "than current verified STH: %s vs %s " %
                              (sth_response, verified_sth))
                return False

        # Keep only the verification call inside the try so that failures
        # elsewhere are not misreported as an invalid STH.
        try:
            self.__verifier.verify_sth(sth_response)
        except (error.EncodingError, error.VerifyError) as e:
            # Include the reason so the log says why the STH was rejected.
            logging.error("Invalid STH: %s (%s)" % (sth_response, e))
            return False

        # Given that we now only store verified STHs, the audit info here
        # is not all that useful.
        # TODO(ekasper): we should be tracking consistency instead.
        # NOTE(review): the STH is persisted before the consistency check
        # below, so an STH that fails that check remains in the database
        # marked VERIFIED. Confirm this ordering is intended.
        audited_sth = client_pb2.AuditedSth()
        audited_sth.sth.CopyFrom(sth_response)
        audited_sth.audit.status = client_pb2.VERIFIED
        self.__db.store_sth(self.servername, audited_sth)

        # Verify consistency to catch the log trying to trick us
        # into rewinding the tree.
        try:
            self._verify_consistency(self.__state.verified_sth, sth_response)
        except error.VerifyError as e:
            # Every other rejection path in this method logs; do so here too
            # rather than failing silently.
            logging.error("STH from %s is not consistent with the current "
                          "verified STH: %s" % (self.servername, e))
            return False

        # We now have a valid STH that is newer than our current STH: we should
        # be holding on to it until we have downloaded and verified data under
        # its signature.
        logging.info("STH verified, updating state.")
        self._set_pending_sth(sth_response)
        return True
Пример #11
0
class LogDBTest(object):
    """All LogDB tests should derive from this class as well as
    unittest.TestCase.

    Subclasses implement db() to return the database under test. Every test
    here reaches the database only through db().

    Iterators are advanced with the next() builtin rather than the Python-2
    only .next() method, so the suite also runs on Python 3.
    """
    __metaclass__ = abc.ABCMeta

    # Set up a default fake test log server and STH.
    default_log = client_pb2.CtLogMetadata()
    default_log.log_server = "test"
    default_log.log_id = "somekeyid"
    default_log.public_key_info.type = client_pb2.KeyInfo.ECDSA
    default_log.public_key_info.pem_key = "base64encodedkey"

    default_sth = client_pb2.AuditedSth()
    default_sth.sth.timestamp = 1234
    default_sth.sth.sha256_root_hash = "base64hash"
    default_sth.audit.status = client_pb2.VERIFIED

    @abc.abstractmethod
    def db(self):
        """Derived classes must override to initialize a database."""
        pass

    @staticmethod
    def _numbered_sth(i):
        """Return a new AuditedSth with timestamp i and root hash "hash-<i>"."""
        sth = client_pb2.AuditedSth()
        sth.sth.timestamp = i
        sth.sth.sha256_root_hash = "hash-%d" % i
        return sth

    def _store_numbered_sths(self, log_server, count):
        """Store _numbered_sth(0) .. _numbered_sth(count - 1) for log_server."""
        for i in range(count):
            self.db().store_sth(log_server, self._numbered_sth(i))

    def test_add_log(self):
        self.db().add_log(LogDBTest.default_log)
        generator = self.db().logs()
        metadata = next(generator)
        self.assertEqual(metadata, LogDBTest.default_log)
        self.assertRaises(StopIteration, next, generator)

    def test_update_log(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)

        new_log = client_pb2.CtLogMetadata()
        new_log.CopyFrom(LogDBTest.default_log)
        new_log.public_key_info.pem_key = "newkey"
        self.db().update_log(new_log)
        generator = self.db().logs()
        metadata = next(generator)
        self.assertEqual(metadata, new_log)
        self.assertRaises(StopIteration, next, generator)

        # Should still be able to access STHs after updating log metadata
        read_sth = self.db().get_latest_sth(new_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_update_log_adds_log(self):
        self.db().update_log(LogDBTest.default_log)
        generator = self.db().logs()
        metadata = next(generator)
        self.assertEqual(metadata, LogDBTest.default_log)
        self.assertRaises(StopIteration, next, generator)

    def test_store_sth(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_store_sth_ignores_duplicate(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        duplicate_sth = client_pb2.AuditedSth()
        # Copy the STH so this is a genuine duplicate (same timestamp and
        # root hash). Without it the "duplicate" is an empty STH with
        # timestamp 0, which is merely older and proves nothing.
        duplicate_sth.CopyFrom(LogDBTest.default_sth)
        duplicate_sth.audit.status = client_pb2.VERIFY_ERROR
        self.db().store_sth(LogDBTest.default_log.log_server, duplicate_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_log_not_found_raises(self):
        self.assertRaises(database.KeyError,
                          self.db().store_sth,
                          LogDBTest.default_log.log_server,
                          LogDBTest.default_sth)

    def test_get_latest_sth_returns_latest(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        new_sth = client_pb2.AuditedSth()
        new_sth.CopyFrom(LogDBTest.default_sth)
        new_sth.sth.timestamp = LogDBTest.default_sth.sth.timestamp - 1
        self.db().store_sth(LogDBTest.default_log.log_server, new_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_get_latest_sth_returns_none_if_empty(self):
        self.db().add_log(LogDBTest.default_log)
        self.assertIsNone(self.db().get_latest_sth(
            LogDBTest.default_log.log_server))

    def test_get_latest_sth_honours_log_server(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        new_sth = client_pb2.AuditedSth()
        new_sth.CopyFrom(LogDBTest.default_sth)
        new_sth.sth.timestamp = LogDBTest.default_sth.sth.timestamp + 1

        new_log = client_pb2.CtLogMetadata()
        new_log.log_server = "test2"
        self.db().add_log(new_log)

        new_sth.sth.sha256_root_hash = "hash2"
        self.db().store_sth(new_log.log_server, new_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_scan_latest_sth_range_finds_all(self):
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        self._store_numbered_sths(log_server, 4)

        generator = self.db().scan_latest_sth_range(log_server)
        for i in range(3, -1, -1):
            sth = next(generator)
            # Scan runs in descending timestamp order
            self.assertEqual(sth.sth.timestamp, i)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % i)

        self.assertRaises(StopIteration, next, generator)

    def test_scan_latest_sth_range_honours_log_server(self):
        for i in range(4):
            log = client_pb2.CtLogMetadata()
            log.log_server = "test-%d" % i
            self.db().add_log(log)
        for i in range(4):
            self.db().store_sth("test-%d" % i, self._numbered_sth(i))

        for i in range(4):
            generator = self.db().scan_latest_sth_range("test-%d" % i)
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, i)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % i)
            # Each log holds exactly one STH; nothing from other logs may
            # follow.
            self.assertRaises(StopIteration, next, generator)

    def test_scan_latest_sth_range_honours_range(self):
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        self._store_numbered_sths(log_server, 4)

        # Bounds are inclusive: start=1, end=2 yields timestamps 2 then 1.
        generator = self.db().scan_latest_sth_range(log_server, start=1,
                                                    end=2)
        for expected in (2, 1):
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, expected)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % expected)

        self.assertRaises(StopIteration, next, generator)

    def test_scan_latest_sth_range_honours_limit(self):
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        self._store_numbered_sths(log_server, 4)

        generator = self.db().scan_latest_sth_range(log_server, limit=1)
        sth = next(generator)
        # Returns most recent
        self.assertEqual(sth.sth.timestamp, 3)
        self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % 3)

        self.assertRaises(StopIteration, next, generator)