class ObjectDatabaseOverChannelTestsWithRedis(unittest.TestCase,
                                              ObjectDatabaseTests):
    """Run the shared ObjectDatabaseTests suite against a RedisPersistence
    store served by an InMemServer.

    Each test starts its own redis-server on port 1115, with its log file and
    dump file placed in a fresh temporary directory.
    """

    # Port used for the per-test redis-server and the clients talking to it.
    _REDIS_PORT = 1115

    @classmethod
    def setUpClass(cls):
        ObjectDatabaseTests.setUpClass()

    def setUp(self):
        self.tempDir = tempfile.TemporaryDirectory()
        self.tempDirName = self.tempDir.__enter__()
        # Cleanups run even when setUp fails part-way (tearDown does not), and
        # they run after tearDown in reverse registration order, so the temp
        # directory is removed last.
        self.addCleanup(self.tempDir.__exit__, None, None, None)

        self.auth_token = genToken()

        # unittest uses a fresh instance per test method, so this normally
        # finds nothing; it guards against a process left on the instance.
        if getattr(self, 'redisProcess', None):
            self.redisProcess.terminate()
            self.redisProcess.wait()

        self.redisProcess = subprocess.Popen([
            "/usr/bin/redis-server", '--port', str(self._REDIS_PORT),
            '--logfile', os.path.join(self.tempDirName, "log.txt"),
            "--dbfilename", "db.rdb", "--dir", self.tempDirName
        ])
        self.addCleanup(self._stopRedis)

        # Give redis a moment to start (or to exit) before checking that it
        # is still alive and answering requests.
        time.sleep(.5)
        self.assertIsNone(self.redisProcess.poll(),
                          "redis-server exited during startup")
        redis.StrictRedis(db=0, decode_responses=True,
                          port=self._REDIS_PORT).echo("hi")

        self.mem_store = RedisPersistence(port=self._REDIS_PORT)
        self.server = InMemServer(self.mem_store, self.auth_token)
        self.server._gc_interval = .1
        self.server.start()

    def _stopRedis(self):
        """Terminate and reap the redis-server started by setUp, if any."""
        if self.redisProcess is not None:
            self.redisProcess.terminate()
            self.redisProcess.wait()
            self.redisProcess = None

    def createNewDb(self):
        """Open a new authenticated connection to this test's server."""
        return self.server.connect(self.auth_token)

    def tearDown(self):
        # The redis-server and the temp directory are released by the cleanups
        # registered in setUp, which run after this method even if it raises.
        self.server.stop()

    # The following no-op overrides presumably replace tests inherited from
    # ObjectDatabaseTests so that they do not run against the redis backend.
    def test_throughput(self):
        pass

    def test_object_versions_robust(self):
        pass

    def test_flush_db_works(self):
        pass
class RingInvariantTest(unittest.TestCase):
    """Check Ring invariants after inserts, and when a ring is read back
    through fresh connections using eager, lazy, or index-restricted
    subscriptions."""

    @classmethod
    def setUpClass(cls):
        configureLogging('database_ring_invariant_test')

    def setUp(self):
        self.token = genToken()
        self.mem_store = InMemoryPersistence()
        self.server = InMemServer(self.mem_store, self.token)
        self.server.start()

    def createNewDb(self):
        """Open a new authenticated connection to this test's server."""
        return self.server.connect(self.token)

    def tearDown(self):
        self.server.stop()

    def test_ring_invariants_basic(self):
        db = self.createNewDb()
        db.subscribeToSchema(schema)
        with db.transaction():
            # Start from an empty ring. After the (i + 1)-th insert, check()
            # is expected to return (i + 2, 0).
            ring = Ring.New()
            for i in range(10):
                ring.insert(i)
                self.assertEqual(ring.check(), (i + 2, 0))

    def test_ring_invariants_reader_writer(self):
        """Each pass inserts one value on a shared connection, then checks a
        ring from a brand-new connection, alternating subscription modes."""
        db = self.createNewDb()
        db.subscribeToSchema(schema)

        with db.transaction():
            # Create the empty ring that every pass inserts into.
            Ring.New()

        def writeSome():
            # Insert one random value into a randomly chosen ring.
            with db.transaction():
                rings = Ring.lookupAll()
                ring = rings[numpy.random.choice(len(rings))]
                ring.insert(numpy.random.choice(10))

        def checkSome(lazy, k=None):
            # Subscribe a new connection and return check() of a random ring.
            db2 = self.createNewDb()

            if k is not None:
                # A k-restricted index subscription is only requested
                # together with a lazy type subscription.
                assert lazy
                db2.subscribeToType(schema.Ring, lazySubscription=True)
                db2.subscribeToIndex(schema.Ring, k=k)
            else:
                db2.subscribeToSchema(schema)

            with db2.transaction():
                rings = Ring.lookupAll()
                return rings[numpy.random.choice(len(rings))].check()

        for i in range(100):
            writeSome()

            isLazy = (i % 2) == 0

            # Equivalent to i % 10 == 8: those passes use an index-restricted
            # lazy subscription; all others subscribe without an index.
            k = None if i % 5 != 3 or not isLazy else numpy.random.choice(10)

            context = 'pass=%s isLazy=%s k=%s' % (i, isLazy, k)

            count, ringSum = checkSome(isLazy, k)
            self.assertEqual(count, i + 2, context)
            self.assertEqual(ringSum, 0, context)
class ObjectDatabaseOverChannelTests(unittest.TestCase, ObjectDatabaseTests):
    """Run the shared ObjectDatabaseTests suite against an InMemServer backed
    by InMemoryPersistence, plus channel-specific connection tests."""

    @classmethod
    def setUpClass(cls):
        ObjectDatabaseTests.setUpClass()

    def setUp(self):
        self.auth_token = genToken()

        self.mem_store = InMemoryPersistence()
        self.server = InMemServer(self.mem_store, self.auth_token)
        self.server._gc_interval = .1
        self.server.start()

    def createNewDb(self):
        """Open a new authenticated connection to this test's server."""
        return self.server.connect(self.auth_token)

    def tearDown(self):
        self.server.stop()

    def test_connection_without_auth_disconnects(self):
        """A connection made over a raw channel without authenticating should
        end up raising DisconnectedException."""
        db = DatabaseConnection(self.server.getChannel())

        old_interval = messages.getHeartbeatInterval()
        messages.setHeartbeatInterval(.25)

        try:
            with self.assertRaises(DisconnectedException):
                db.subscribeToSchema(schema)

        finally:
            messages.setHeartbeatInterval(old_interval)

    def test_heartbeats(self):
        """A connection that stops heartbeating is dropped: its Connection
        object disappears for the other client and it becomes unusable."""
        old_interval = messages.getHeartbeatInterval()
        messages.setHeartbeatInterval(.25)

        try:
            db1 = self.createNewDb()
            db2 = self.createNewDb()

            db1.subscribeToSchema(core_schema)
            db2.subscribeToSchema(core_schema)

            # Both clients must see both connections. (assertTrue would treat
            # the 2 as a message and pass for any non-zero count.)
            with db1.view():
                self.assertEqual(len(core_schema.Connection.lookupAll()), 2)

            with db2.view():
                self.assertEqual(len(core_schema.Connection.lookupAll()), 2)

            db1._stopHeartbeating()

            db2.waitForCondition(
                lambda: len(core_schema.Connection.lookupAll()) == 1,
                5.0 * self.PERFORMANCE_FACTOR)

            with db2.view():
                self.assertEqual(len(core_schema.Connection.lookupAll()), 1)

            with self.assertRaises(DisconnectedException):
                with db1.view():
                    pass
        finally:
            messages.setHeartbeatInterval(old_interval)

    def test_multithreading_and_cleanup(self):
        """Verify that threads concurrently creating/deleting Counters and
        looking them up by index, spread over two connections, all finish
        without raising."""

        try:
            # Inject extra checking that slows set edits down, so that races
            # this test is looking for are more likely to surface.
            SetWithEdits.AGRESSIVELY_CHECK_SET_ADDS_NOT_CHANGING = True

            db1 = self.createNewDb()
            db1.subscribeToType(Counter)

            db2 = self.createNewDb()
            db2.subscribeToType(Counter)

            shouldStop = [False]
            isOK = []

            threadcount = 4

            def readerthread(db):
                c = None
                while not shouldStop[0]:
                    if numpy.random.uniform() < .5:
                        if c is None:
                            with db.transaction():
                                c = Counter(k=0)
                        else:
                            with db.transaction():
                                c.delete()
                                c = None
                    else:
                        with db.view():
                            Counter.lookupAny(k=0)

                # Only reached when the loop exits without raising.
                isOK.append(True)

            # Alternate threads between the two connections by thread index.
            threads = [
                threading.Thread(target=readerthread,
                                 args=(db1 if i % 2 else db2, ))
                for i in range(threadcount)
            ]

            try:
                for t in threads:
                    t.start()

                time.sleep(1.0)
            finally:
                # Always stop and reap the workers so a failure here cannot
                # leave them spinning; only started threads can be joined.
                shouldStop[0] = True
                for t in threads:
                    if t.ident is not None:
                        t.join()

            self.assertEqual(len(isOK), threadcount)
        finally:
            SetWithEdits.AGRESSIVELY_CHECK_SET_ADDS_NOT_CHANGING = False