class StateKeeperTest(unittest.TestCase):
    """Tests for state.StateKeeper reading and writing MonitorState files."""

    _DEFAULT_STATE = client_pb2.MonitorState()
    # fill with some data
    _DEFAULT_STATE.verified_sth.timestamp = 1234
    _DEFAULT_STATE.pending_sth.tree_size = 5678

    def _make_temp_file(self, contents=None):
        """Create a temporary file that is removed when the test finishes.

        Args:
            contents: optional bytes written to the file before it is closed.

        Returns:
            The path of the new file.
        """
        handle, path = tempfile.mkstemp()
        try:
            if contents is not None:
                os.write(handle, contents)
        finally:
            # Close the descriptor even if the write fails.
            os.close(handle)
        # Registered as a cleanup so the file is removed even when an
        # assertion fails (removing it at the end of the test body would be
        # skipped on failure).
        self.addCleanup(os.remove, path)
        return path

    def test_read_write(self):
        state_file = self._make_temp_file()
        state_keeper = state.StateKeeper(state_file)
        state_keeper.write(self._DEFAULT_STATE)
        self.assertEqual(self._DEFAULT_STATE,
                         state_keeper.read(client_pb2.MonitorState))

    def test_read_no_such_file(self):
        import shutil

        temp_dir = tempfile.mkdtemp()
        # Remove the directory afterwards; ignore_errors=True so cleanup never
        # masks the real test outcome.
        self.addCleanup(shutil.rmtree, temp_dir, True)
        state_keeper = state.StateKeeper(os.path.join(temp_dir, "foo"))
        self.assertRaises(state.FileNotFoundError, state_keeper.read,
                          client_pb2.MonitorState)

    def test_read_corrupt_file(self):
        # os.write requires bytes on Python 3; b"" literals also work on 2.6+.
        state_file = self._make_temp_file(b"wibble")
        state_keeper = state.StateKeeper(state_file)
        self.assertRaises(state.CorruptStateError, state_keeper.read,
                          client_pb2.MonitorState)
# ---- Example #2 (0) ----
    def __init__(self, client, verifier, hasher, db, cert_db, log_key,
                 state_keeper):
        """Initialise the monitor and restore any previously saved state.

        Args:
            client: log client; stored for use by other methods.
            verifier: stored for use by other methods.
            hasher: hasher passed to both CompactMerkleTree instances.
            db: stored for use by other methods.
            cert_db: certificate DB handed to the DB certificate reporter.
            log_key: key passed alongside cert_db to the DB reporter.
            state_keeper: provides read(client_pb2.MonitorState). A missing
                state file (state.FileNotFoundError) is treated as a first
                run and only logged.
        """
        self.__client = client
        self.__verifier = verifier
        self.__hasher = hasher
        self.__db = db
        self.__state_keeper = state_keeper

        # Start from an empty state; replaced below if a saved one exists.
        self.__state = client_pb2.MonitorState()
        # Aggregate a text report and a cert-DB report.
        self.__report = aggregated_reporter.AggregatedCertificateReport(
            (text_reporter.TextCertificateReport(),
             db_reporter.CertDBCertificateReport(cert_db, log_key)))
        try:
            self.__state = self.__state_keeper.read(client_pb2.MonitorState)
        except state.FileNotFoundError:
            # TODO(ekasper): initialize state file with a setup script, so we
            # can raise with certainty when it's not found.
            logging.warning("Monitor state file not found, assuming first "
                            "run.")
        else:
            if not self.__state.HasField("verified_sth"):
                logging.warning(
                    "No verified monitor state, assuming first run.")

        # Restore the compact Merkle trees stored inside the monitor state.
        self._verified_tree = merkle.CompactMerkleTree(hasher)
        self._unverified_tree = merkle.CompactMerkleTree(hasher)
        self._verified_tree.load(self.__state.verified_tree)
        self._unverified_tree.load(self.__state.unverified_tree)
# ---- Example #3 (0) ----
        def check_state(result):
            """Check (ignoring `result`) that state was kept, nothing stored."""
            kept = client_pb2.MonitorState()
            kept.verified_sth.CopyFrom(self._DEFAULT_STH)
            self.verify_state(kept)

            # Neither the temp DB nor the audited-STH table may have changed.
            self.assertFalse(self.temp_db.store_entries.called)
            self.check_db_state_after_successful_updates(0)
# ---- Example #4 (0) ----
        def check_state(result):
            """Check the consistency check saw time-ordered STHs; state kept."""
            consistency_check = m._verify_consistency
            self.assertTrue(consistency_check.called)
            positional, _kwargs = consistency_check.call_args
            older, newer = positional[0], positional[1]
            self.assertTrue(older.timestamp < newer.timestamp)

            # The verified STH must be the only thing in the kept state.
            kept = client_pb2.MonitorState()
            kept.verified_sth.CopyFrom(self._DEFAULT_STH)
            self.verify_state(kept)
    def test_update_sth_fails_on_client_error(self):
        """An HTTPError from get_sth makes _update_sth fail, keeping state."""
        client = FakeLogClient(self._NEW_STH)
        failing_get_sth = mock.Mock(side_effect=log_client.HTTPError("Boom!"))
        client.get_sth = failing_get_sth

        mon = self.create_monitor(client)
        self.assertFalse(mon._update_sth())

        kept = client_pb2.MonitorState()
        kept.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.verify_state(kept)
    def test_update_sth_fails_for_invalid_sth(self):
        """A VerifyError from verify_sth makes _update_sth return False."""
        client = FakeLogClient(self._NEW_STH)
        self.verifier.verify_sth.side_effect = error.VerifyError("Boom!")

        mon = self.create_monitor(client)
        self.assertFalse(mon._update_sth())

        # The verified STH is still the only thing in the state.
        kept = client_pb2.MonitorState()
        kept.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.verify_state(kept)
    def test_update_sth(self):
        """A successful _update_sth stores the new STH as pending."""
        mon = self.create_monitor(FakeLogClient(self._NEW_STH))
        self.assertTrue(mon._update_sth())

        expected_state = client_pb2.MonitorState()
        expected_state.verified_sth.CopyFrom(self._DEFAULT_STH)
        expected_state.pending_sth.CopyFrom(self._NEW_STH)
        # The verified tree is expected to be an empty compact tree.
        merkle.CompactMerkleTree().save(expected_state.verified_tree)
        self.verify_state(expected_state)
# ---- Example #8 (0) ----
 def check_state(result):
     """Check the updated state and the two audited STHs it produced.

     Args:
         result: ignored.
     """
     # Check that we updated the state.
     expected_state = client_pb2.MonitorState()
     expected_state.verified_sth.CopyFrom(self._DEFAULT_STH)
     expected_state.pending_sth.CopyFrom(self._NEW_STH)
     merkle.CompactMerkleTree().save(expected_state.verified_tree)
     merkle.CompactMerkleTree().save(expected_state.unverified_tree)
     self.verify_state(expected_state)
     audited_sths = list(self.db.scan_latest_sth_range("log_server"))
     # Check the count before indexing so a short result is reported as a
     # test failure rather than an IndexError.
     self.assertEqual(len(audited_sths), 2)
     self.assertEqual(audited_sths[0].audit.status, client_pb2.VERIFIED)
     self.assertEqual(audited_sths[1].audit.status,
                      client_pb2.UNVERIFIED)
# ---- Example #9 (0) ----
 def _set_pending_sth(self, new_sth):
     """Record new_sth as pending_sth, or as verified_sth if it adds no entries.

     Raises:
         ValueError: if new_sth's tree is smaller than the verified one, or
             its timestamp is not strictly newer.
     """
     verified = self.__state.verified_sth
     if new_sth.tree_size < verified.tree_size:
         raise ValueError("pending size must be >= verified size")
     if new_sth.timestamp <= verified.timestamp:
         raise ValueError("pending time must be > verified time")
     grows_tree = new_sth.tree_size > verified.tree_size

     new_state = client_pb2.MonitorState()
     new_state.CopyFrom(self.__state)
     target = new_state.pending_sth if grows_tree else new_state.verified_sth
     target.CopyFrom(new_sth)
     self.__update_state(new_state)
    def test_update_sth_fails_for_inconsistent_sth(self):
        """A consistency-check failure makes _update_sth fail, keeping state."""
        client = FakeLogClient(self._NEW_STH)
        # The STH is in fact OK; the consistency failure is faked.
        faked_failure = error.ConsistencyError("Boom!")
        self.verifier.verify_sth_consistency.side_effect = faked_failure

        mon = self.create_monitor(client)
        self.assertFalse(mon._update_sth())

        kept = client_pb2.MonitorState()
        kept.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.verify_state(kept)
# ---- Example #11 (0) ----
 def check_state(result):
     """Check state was kept and both audited STHs carry the old root hash."""
     kept = client_pb2.MonitorState()
     kept.verified_sth.CopyFrom(self._DEFAULT_STH)
     self.verify_state(kept)

     audited_sths = list(self.db.scan_latest_sth_range("log_server"))
     self.assertEqual(len(audited_sths), 2)
     first, second = audited_sths
     self.assertEqual(first.audit.status, client_pb2.VERIFY_ERROR)
     self.assertEqual(second.audit.status, client_pb2.UNVERIFIED)
     expected_hash = self._DEFAULT_STH.sha256_root_hash
     for audited_sth in (first, second):
         self.assertEqual(expected_hash, audited_sth.sth.sha256_root_hash)
    def test_update_sth_fails_for_stale_sth(self):
        """An STH older and smaller than the verified one is rejected."""
        stale_sth = client_pb2.SthResponse()
        stale_sth.CopyFrom(self._DEFAULT_STH)
        stale_sth.tree_size -= 1
        stale_sth.timestamp -= 1

        mon = self.create_monitor(FakeLogClient(stale_sth))
        self.assertFalse(mon._update_sth())

        kept = client_pb2.MonitorState()
        kept.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.verify_state(kept)
    def test_update_no_new_entries(self):
        """update() against an unchanged STH succeeds but stores nothing."""
        client = FakeLogClient(self._DEFAULT_STH)
        self.temp_db.store_entries = mock.Mock()

        mon = self.create_monitor(client)
        self.assertTrue(mon.update())

        # State is unchanged...
        kept = client_pb2.MonitorState()
        kept.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.verify_state(kept)

        # ...and no entries reached the temp DB.
        self.assertFalse(self.temp_db.store_entries.called)
# ---- Example #14 (0) ----
 def _set_verified_tree(self, new_tree):
     """Set verified_tree and maybe move pending_sth to verified_sth.

     Args:
         new_tree: compact Merkle tree over the entries verified so far; its
             tree_size must not exceed pending_sth.tree_size.
     """
     # Use the attribute __init__ creates and loads (self._verified_tree).
     # Assigning self.__verified_tree instead would create a separate,
     # name-mangled attribute that __init__ never initialises.
     self._verified_tree = new_tree
     old_state = self.__state
     new_state = client_pb2.MonitorState()
     new_state.CopyFrom(self.__state)
     assert old_state.pending_sth.tree_size >= new_tree.tree_size
     if old_state.pending_sth.tree_size == new_tree.tree_size:
         # all pending entries retrieved
         # already did consistency checks so this should always be true
         assert (old_state.pending_sth.sha256_root_hash ==
                 self._verified_tree.root_hash())
         new_state.verified_sth.CopyFrom(old_state.pending_sth)
         new_state.ClearField("pending_sth")
     self.__update_state(new_state)
# ---- Example #15 (0) ----
        def check_state(result):
            """Check the written state, then the audited STHs of one update."""
            projected_tree = m._compute_projected_sth_from_tree.dummy_tree
            expected_state = client_pb2.MonitorState()
            expected_state.verified_sth.CopyFrom(self._DEFAULT_STH)
            projected_tree.save(expected_state.verified_tree)
            projected_tree.save(expected_state.unverified_tree)
            self.verify_state(expected_state)

            self.verify_tmp_data(0, self._DEFAULT_STH.tree_size - 1)
            self.check_db_state_after_successful_updates(1)
            audited = list(self.db.scan_latest_sth_range("log_server"))
            for audited_sth in audited:
                self.assertEqual(self._DEFAULT_STH, audited_sth.sth)
    def test_first_update(self):
        """With the kept state cleared, update() writes the default STH."""
        client = FakeLogClient(self._DEFAULT_STH)

        self.state_keeper.state = None
        mon = self.create_monitor(client)
        projected = self._DEFAULT_STH_compute_projected
        mon._compute_projected_sth = projected
        self.assertTrue(mon.update())

        written = client_pb2.MonitorState()
        written.verified_sth.CopyFrom(self._DEFAULT_STH)
        mon._compute_projected_sth.dummy_tree.save(written.verified_tree)
        self.verify_state(written)

        self.verify_tmp_data(0, self._DEFAULT_STH.tree_size - 1)
    def setUp(self):
        """Create in-memory DBs, mocks and a state seeded with _DEFAULT_STH."""
        if not FLAGS.verbose_tests:
            logging.disable(logging.CRITICAL)

        def in_memory_connection():
            return sqlitecon.SQLiteConnectionManager(":memory:",
                                                     keepalive=True)

        self.db = sqlite_log_db.SQLiteLogDB(in_memory_connection())
        self.temp_db = sqlite_temp_db.SQLiteTempDB(in_memory_connection())

        seeded_state = client_pb2.MonitorState()
        seeded_state.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.state_keeper = InMemoryStateKeeper(seeded_state)
        self.verifier = mock.Mock()
        self.hasher = merkle.TreeHasher()

        # Register the default log server with the DB.
        server_log = client_pb2.CtLogMetadata()
        server_log.log_server = "log_server"
        self.db.add_log(server_log)
# ---- Example #18 (0) ----
    def setUp(self):
        """Create in-memory DBs, mocked collaborators and a seeded state."""
        if not FLAGS.verbose_tests:
            logging.disable(logging.CRITICAL)

        def in_memory_connection():
            return sqlitecon.SQLiteConnectionManager(":memory:",
                                                     keepalive=True)

        self.db = sqlite_log_db.SQLiteLogDB(in_memory_connection())
        self.temp_db = sqlite_temp_db.SQLiteTempDB(in_memory_connection())
        # The cert DB is mocked: an in-memory SQLite DB with keepalive=True
        # can't be used because a different thread writes to it, which
        # raises an sqlite exception.
        self.cert_db = mock.MagicMock()

        seeded_state = client_pb2.MonitorState()
        seeded_state.verified_sth.CopyFrom(self._DEFAULT_STH)
        self.state_keeper = InMemoryStateKeeper(seeded_state)
        self.verifier = mock.Mock()
        self.hasher = merkle.TreeHasher()

        # Register the default log server with the DB.
        server_log = client_pb2.CtLogMetadata()
        server_log.log_server = "log_server"
        self.db.add_log(server_log)
# ---- Example #19 (0) ----
 def check_state(result):
     """Check the state was kept and no audited STHs were recorded."""
     kept = client_pb2.MonitorState()
     kept.verified_sth.CopyFrom(self._DEFAULT_STH)
     self.verify_state(kept)
     self.check_db_state_after_successful_updates(0)
# ---- Example #20 (0) ----
 def _update_unverified_data(self, unverified_tree):
     """Adopt unverified_tree, then pass a copy of the state to __update_state.

     NOTE(review): __update_state is not visible here; presumably it saves the
     trees into the state and persists it — confirm.
     """
     self._unverified_tree = unverified_tree
     snapshot = client_pb2.MonitorState()
     snapshot.CopyFrom(self.__state)
     self.__update_state(snapshot)
# ---- Example #21 (0) ----
class MonitorTest(unittest.TestCase):
    # NOTE(review): test methods below return Deferreds, so `unittest` here is
    # presumably twisted.trial.unittest rather than the stdlib module.

    # Baseline STH; _DEFAULT_STATE (below) marks it verified.
    _DEFAULT_STH = client_pb2.SthResponse()
    _DEFAULT_STH.timestamp = 2000
    _DEFAULT_STH.tree_size = 10
    _DEFAULT_STH.tree_head_signature = "sig"
    # Stand-in for Monitor._compute_projected_sth_from_tree; tests also save
    # its dummy_tree into expected states.
    _DEFAULT_STH_compute_projected = dummy_compute_projected_sth(_DEFAULT_STH)

    # A later STH with 10 more entries, served by FakeLogClient as "new".
    _NEW_STH = client_pb2.SthResponse()
    _NEW_STH.timestamp = 3000
    _NEW_STH.tree_size = _DEFAULT_STH.tree_size + 10
    _NEW_STH.tree_head_signature = "sig2"
    _NEW_STH_compute_projected = dummy_compute_projected_sth(_NEW_STH)

    # State with _DEFAULT_STH verified and both trees saved from the default
    # dummy tree; setUp hands a deep copy of it to the state keeper.
    _DEFAULT_STATE = client_pb2.MonitorState()
    _DEFAULT_STATE.verified_sth.CopyFrom(_DEFAULT_STH)
    _DEFAULT_STH_compute_projected.dummy_tree.save(
        _DEFAULT_STATE.unverified_tree)
    _DEFAULT_STH_compute_projected.dummy_tree.save(
        _DEFAULT_STATE.verified_tree)

    def setUp(self):
        """Create an in-memory log DB, mocks and a copy of _DEFAULT_STATE."""
        if not FLAGS.verbose_tests:
            logging.disable(logging.CRITICAL)
        connection = sqlitecon.SQLiteConnectionManager(":memory:",
                                                       keepalive=True)
        self.db = sqlite_log_db.SQLiteLogDB(connection)
        # The cert DB is mocked: an in-memory SQLite DB with keepalive=True
        # can't be used because a different thread writes to it, which
        # raises an sqlite exception.
        self.cert_db = mock.MagicMock()

        # Deep-copied so changes to the kept state never touch the shared
        # class-level _DEFAULT_STATE.
        initial_state = copy.deepcopy(self._DEFAULT_STATE)
        self.state_keeper = InMemoryStateKeeper(initial_state)
        self.verifier = mock.Mock()
        self.hasher = merkle.TreeHasher()

        # Register the default log server with the DB.
        default_log = client_pb2.CtLogMetadata()
        default_log.log_server = "log_server"
        self.db.add_log(default_log)

    def verify_state(self, expected_state):
        """Fail the test with a unified diff unless the kept state matches.

        Args:
            expected_state: MonitorState the state keeper should currently
                hold.
        """
        actual_state = self.state_keeper.state
        if actual_state != expected_state:
            state_diff = difflib.unified_diff(
                str(expected_state).splitlines(),
                str(actual_state).splitlines(),
                fromfile="expected",
                tofile="actual",
                lineterm="",
                n=5)
            # self.fail raises the runner's failureException (FailTest under
            # twisted.trial, AssertionError under stdlib unittest), whereas
            # unittest.FailTest only exists in twisted.trial.
            self.fail("State is incorrect\n" + "\n".join(state_diff))

    def verify_tmp_data(self, start, end):
        """No-op placeholder: the temp DB is no longer used.

        ``start`` and ``end`` are accepted for existing callers and ignored.
        """
        # TODO: update all call sites to test the main db instead, then
        # remove this helper.
        return None

    def create_monitor(self, client, skip_scan_entry=True):
        """Build a monitor.Monitor wired to this test's fixtures.

        Args:
            client: log client the monitor talks to.
            skip_scan_entry: when true (the default), the monitor's
                _scan_entries is replaced by a mock so entries are not
                scanned.

        Returns:
            The new monitor.Monitor.
        """
        # 7 is passed as the monitor's log_key argument.
        m = monitor.Monitor(client, self.verifier, self.hasher, self.db,
                            self.cert_db, 7, self.state_keeper)
        # Previously guarded by `if m:`, which ignored skip_scan_entry
        # entirely; honour the flag instead (default behaviour unchanged).
        if skip_scan_entry:
            m._scan_entries = mock.Mock()
        return m

    def check_db_state_after_successful_updates(self, number_of_updates):
        """Assert audited STHs alternate VERIFIED/UNVERIFIED, two per update."""
        audited_sths = list(self.db.scan_latest_sth_range("log_server"))
        for position, audited_sth in enumerate(audited_sths):
            expected_status = (client_pb2.VERIFIED if position % 2 == 0
                               else client_pb2.UNVERIFIED)
            self.assertEqual(expected_status, audited_sth.audit.status)
        self.assertEqual(len(audited_sths), number_of_updates * 2)

    def test_update(self):
        """A full update to _NEW_STH verifies it and saves both trees."""
        client = FakeLogClient(self._NEW_STH)
        m = self.create_monitor(client)
        m._compute_projected_sth_from_tree = self._NEW_STH_compute_projected

        def check_state(result):
            projected_tree = m._compute_projected_sth_from_tree.dummy_tree
            expected_state = client_pb2.MonitorState()
            expected_state.verified_sth.CopyFrom(self._NEW_STH)
            projected_tree.save(expected_state.verified_tree)
            projected_tree.save(expected_state.unverified_tree)
            self.verify_state(expected_state)

            self.verify_tmp_data(self._DEFAULT_STH.tree_size,
                                 self._NEW_STH.tree_size - 1)
            self.check_db_state_after_successful_updates(1)
            for audited_sth in self.db.scan_latest_sth_range(m.servername):
                self.assertEqual(self._NEW_STH, audited_sth.sth)

        deferred = m.update()
        deferred.addCallback(self.assertTrue)
        deferred.addCallback(check_state)
        return deferred

    def test_first_update(self):
        """With the kept state cleared to None, update writes _DEFAULT_STATE."""
        client = FakeLogClient(self._DEFAULT_STH)

        self.state_keeper.state = None
        m = self.create_monitor(client)
        m._compute_projected_sth_from_tree = self._DEFAULT_STH_compute_projected

        def check_state(result):
            self.verify_state(self._DEFAULT_STATE)
            self.verify_tmp_data(0, self._DEFAULT_STH.tree_size - 1)
            self.check_db_state_after_successful_updates(1)
            for audited_sth in self.db.scan_latest_sth_range(m.servername):
                self.assertEqual(self._DEFAULT_STH, audited_sth.sth)

        return m.update().addCallback(self.assertTrue).addCallback(
            check_state)

    def test_update_no_new_entries(self):
        """An update against an unchanged STH keeps state and audits nothing."""
        m = self.create_monitor(FakeLogClient(self._DEFAULT_STH))

        def check_state(result):
            # State unchanged...
            self.verify_state(self._DEFAULT_STATE)
            # ...and no audited STHs were recorded.
            self.check_db_state_after_successful_updates(0)

        d = m.update()
        d.addCallback(self.assertTrue)
        d.addCallback(check_state)
        return d

    def test_update_recovery(self):
        """An update completes from a state left by an interrupted update."""
        client = FakeLogClient(self._NEW_STH)

        # Pretend a previous update recorded the pending STH and its
        # unverified tree, then failed part way through.
        initial_state = copy.deepcopy(self._DEFAULT_STATE)
        initial_state.pending_sth.CopyFrom(self._NEW_STH)
        new_tree = self._NEW_STH_compute_projected.dummy_tree
        new_tree.save(initial_state.unverified_tree)
        self.state_keeper.write(initial_state)

        m = self.create_monitor(client)
        m._compute_projected_sth_from_tree = self._NEW_STH_compute_projected

        def check_state(result):
            expected_state = copy.deepcopy(initial_state)
            expected_state.ClearField("pending_sth")
            expected_state.verified_sth.CopyFrom(self._NEW_STH)
            projected_tree = m._compute_projected_sth_from_tree.dummy_tree
            projected_tree.save(expected_state.verified_tree)
            projected_tree.save(expected_state.unverified_tree)
            self.verify_state(expected_state)

            self.check_db_state_after_successful_updates(1)
            for audited_sth in self.db.scan_latest_sth_range(m.servername):
                self.assertEqual(self._NEW_STH, audited_sth.sth)

        d = m.update()
        d.addCallback(self.assertTrue)
        d.addCallback(check_state)
        return d

    def test_update_rolls_back_unverified_tree_on_scan_error(self):
        """A scan error discards unverified-tree progress but keeps the STH."""
        m = self.create_monitor(FakeLogClient(self._NEW_STH))
        m._compute_projected_sth_from_tree = self._NEW_STH_compute_projected
        m._scan_entries = mock.Mock(side_effect=ValueError("Boom!"))

        def check_state(result):
            # The changes to the unverified tree should have been discarded,
            # so that entries are re-fetched and re-consumed next time.
            expected_state = copy.deepcopy(self._DEFAULT_STATE)
            expected_state.pending_sth.CopyFrom(self._NEW_STH)
            self.verify_state(expected_state)
            # The new STH should have been verified prior to the error.
            audited_sths = list(self.db.scan_latest_sth_range(m.servername))
            self.assertEqual(len(audited_sths), 2)
            statuses = [audited.audit.status for audited in audited_sths]
            self.assertEqual(statuses[0], client_pb2.VERIFIED)
            self.assertEqual(statuses[1], client_pb2.UNVERIFIED)

        d = m.update()
        d.addCallback(self.assertFalse)
        d.addCallback(check_state)
        return d

    def test_update_call_sequence(self):
        # Test that update calls update_sth and update_entries in sequence,
        # and bails on first error, so we can test each of them separately.
        # Each stage checks the calls made by the previous update, then
        # reconfigures the mocks and starts the next update.
        client = FakeLogClient(self._DEFAULT_STH)
        m = self.create_monitor(client)

        d_true = defer.Deferred()
        d_true.callback(True)
        d_false = defer.Deferred()
        d_false.callback(False)

        def expect_both_called_then_fail_sth(result):
            m._update_sth.assert_called_once_with()
            m._update_entries.assert_called_once_with()

            m._update_sth.reset_mock()
            m._update_entries.reset_mock()
            m._update_sth.return_value = copy.deepcopy(d_false)
            return m.update().addCallback(self.assertFalse)

        def expect_entries_skipped_then_fail_entries(result):
            m._update_sth.assert_called_once_with()
            self.assertFalse(m._update_entries.called)

            m._update_sth.reset_mock()
            m._update_entries.reset_mock()
            m._update_sth.return_value = copy.deepcopy(d_true)
            m._update_entries.return_value = copy.deepcopy(d_false)
            return m.update().addCallback(self.assertFalse)

        def expect_both_called(result):
            m._update_sth.assert_called_once_with()
            m._update_entries.assert_called_once_with()

        # Start with a regular, fully successful update.
        m._update_sth = mock.Mock(return_value=copy.deepcopy(d_true))
        m._update_entries = mock.Mock(return_value=copy.deepcopy(d_true))
        d = m.update().addCallback(self.assertTrue)
        for stage in (expect_both_called_then_fail_sth,
                      expect_entries_skipped_then_fail_entries,
                      expect_both_called):
            d.addCallback(stage)
        return d

    def test_update_sth(self):
        """_update_sth stores the new STH as pending; two STHs are audited."""
        m = self.create_monitor(FakeLogClient(self._NEW_STH))

        def check_state(result):
            # Check that we updated the state.
            expected_state = copy.deepcopy(self._DEFAULT_STATE)
            expected_state.pending_sth.CopyFrom(self._NEW_STH)
            self.verify_state(expected_state)
            audited_sths = list(self.db.scan_latest_sth_range(m.servername))
            self.assertEqual(len(audited_sths), 2)
            first, second = audited_sths
            self.assertEqual(first.audit.status, client_pb2.VERIFIED)
            self.assertEqual(second.audit.status, client_pb2.UNVERIFIED)

        d = m._update_sth()
        d.addCallback(self.assertTrue)
        d.addCallback(check_state)
        return d

    def test_update_sth_fails_for_invalid_sth(self):
        """A VerifyError makes _update_sth fail with no state or DB change."""
        client = FakeLogClient(self._NEW_STH)
        self.verifier.verify_sth.side_effect = error.VerifyError("Boom!")
        m = self.create_monitor(client)

        def check_state(result):
            # Check that we kept the state and audited nothing.
            self.verify_state(self._DEFAULT_STATE)
            self.check_db_state_after_successful_updates(0)

        d = m._update_sth()
        d.addCallback(self.assertFalse)
        d.addCallback(check_state)
        return d

    def test_update_sth_fails_for_stale_sth(self):
        """An older, smaller STH is rejected after a time-ordered check."""
        stale_sth = client_pb2.SthResponse()
        stale_sth.CopyFrom(self._DEFAULT_STH)
        stale_sth.tree_size -= 1
        stale_sth.timestamp -= 1
        m = self.create_monitor(FakeLogClient(stale_sth))
        m._verify_consistency = mock.Mock(return_value=defer.succeed(True))

        def check_state(result):
            consistency_check = m._verify_consistency
            self.assertTrue(consistency_check.called)
            positional, _kwargs = consistency_check.call_args
            self.assertTrue(positional[0].timestamp < positional[1].timestamp)

            # Check that we kept the state.
            self.verify_state(self._DEFAULT_STATE)

        d = m._update_sth()
        d.addCallback(self.assertFalse)
        d.addCallback(check_state)
        return d

    def test_update_sth_fails_for_inconsistent_sth(self):
        """A faked ConsistencyError fails _update_sth and marks VERIFY_ERROR."""
        client = FakeLogClient(self._NEW_STH)
        # The STH is in fact OK; the consistency failure is faked.
        faked_failure = error.ConsistencyError("Boom!")
        self.verifier.verify_sth_consistency.side_effect = faked_failure
        m = self.create_monitor(client)

        def check_state(result):
            # Check that we kept the state.
            self.verify_state(self._DEFAULT_STATE)
            audited_sths = list(self.db.scan_latest_sth_range(m.servername))
            self.assertEqual(len(audited_sths), 2)
            first, second = audited_sths
            self.assertEqual(first.audit.status, client_pb2.VERIFY_ERROR)
            self.assertEqual(second.audit.status, client_pb2.UNVERIFIED)
            expected_hash = self._DEFAULT_STH.sha256_root_hash
            for audited_sth in (first, second):
                self.assertEqual(expected_hash,
                                 audited_sth.sth.sha256_root_hash)

        d = m._update_sth()
        d.addCallback(self.assertFalse)
        d.addCallback(check_state)
        return d

    def test_update_sth_fails_on_client_error(self):
        """An HTTPError from get_sth fails _update_sth with no changes."""
        client = FakeLogClient(self._NEW_STH)

        def get_sth():
            # maybeDeferred turns the raised HTTPError into a failed Deferred.
            failing_call = mock.Mock(
                side_effect=log_client.HTTPError("Boom!"))
            return defer.maybeDeferred(failing_call)

        client.get_sth = get_sth
        m = self.create_monitor(client)

        def check_state(result):
            # Check that we kept the state and audited nothing.
            self.verify_state(self._DEFAULT_STATE)
            self.check_db_state_after_successful_updates(0)

        d = m._update_sth()
        d.addCallback(self.assertFalse)
        d.addCallback(check_state)
        return d

    def test_update_entries_fails_on_client_error(self):
        """A get_entries HTTPError fails _update_entries; only the STH moves."""
        client = FakeLogClient(self._NEW_STH,
                               get_entries_throw=log_client.HTTPError("Boom!"))
        producer = client.get_entries(0, self._NEW_STH.tree_size - 2)
        client.get_entries = mock.Mock(return_value=producer)

        m = self.create_monitor(client)

        def check_state(result):
            # Check that we wrote no entries: only pending_sth changed.
            expected_state = copy.deepcopy(self._DEFAULT_STATE)
            expected_state.pending_sth.CopyFrom(self._NEW_STH)
            self.verify_state(expected_state)

        # Get the new STH, then try (and fail) to update entries.
        d = m._update_sth()
        d.addCallback(self.assertTrue)
        d.addCallback(lambda _: m._update_entries())
        d.addCallback(self.assertFalse)
        d.addCallback(check_state)
        return d

    def test_update_entries_fails_not_enough_entries(self):
        """_update_entries fails when the producer's range is cut to (0, 5)."""
        client = FakeLogClient(self._NEW_STH)
        producer = FakeEntryProducer(0, self._NEW_STH.tree_size)
        producer.change_range_after_start(0, 5)
        client.get_entries = mock.Mock(return_value=producer)

        m = self.create_monitor(client)
        # NOTE(review): other tests patch _compute_projected_sth_from_tree;
        # this attribute name may be stale, making the patch a no-op. Confirm.
        m._compute_projected_sth = self._NEW_STH_compute_projected

        def update_entries_and_expect_failure(_):
            return m._update_entries().addCallback(self.assertFalse)

        # Get the new STH first.
        d = m._update_sth()
        d.addCallback(self.assertTrue)
        d.addCallback(update_entries_and_expect_failure)
        return d

    def test_update_entries_fails_in_the_middle(self):
        """A truncated entry stream fails; a retry fetches only the rest."""
        client = FakeLogClient(self._NEW_STH)
        producer = FakeEntryProducer(self._DEFAULT_STH.tree_size,
                                     self._NEW_STH.tree_size)
        producer.change_range_after_start(self._DEFAULT_STH.tree_size,
                                          self._NEW_STH.tree_size - 5)
        client.get_entries = mock.Mock(return_value=producer)

        m = self.create_monitor(client)
        # NOTE(review): other tests patch _compute_projected_sth_from_tree;
        # this attribute name may be stale, making the patch a no-op. Confirm.
        m._compute_projected_sth = self._NEW_STH_compute_projected
        fake_fetch = mock.MagicMock()

        def first_attempt_fails(_):
            return m._update_entries().addCallback(self.assertFalse)

        def try_again_with_all_entries(_):
            m._fetch_entries = fake_fetch
            return m._update_entries()

        def check_remaining_range_fetched(_):
            # With the fixture sizes (10 -> 20, producer range ending at 15),
            # the retry should fetch entries 15..19.
            fake_fetch.assert_called_once_with(15, 19)

        # Get the new STH first.
        d = m._update_sth()
        for step in (self.assertTrue,
                     first_attempt_fails,
                     try_again_with_all_entries,
                     check_remaining_range_fetched):
            d.addCallback(step)
        return d