Esempio n. 1
0
class BoltConnector(Connector):
    """Connector that executes statements over a pool of Bolt connections.

    The pool is created by :meth:`open`. Statements run without a
    transaction borrow a connection for the lifetime of their result;
    explicit transactions hold a connection from :meth:`begin` until
    :meth:`commit`, :meth:`rollback` or a failure reported while running
    a statement inside the transaction.
    """

    scheme = "bolt"

    @property
    def server_agent(self):
        """Return ``server.agent`` as reported by a pooled connection.

        A connection is borrowed from the pool for the lookup and is
        always handed back, even if the attribute access fails.
        """
        cx = self.pool.acquire()
        try:
            return cx.server.agent
        finally:
            self.pool.release(cx)

    def open(self, cx_data):
        """Create the connection pool for this connector.

        :param cx_data: mapping that provides the ``"host"``, ``"port"``
            and ``"auth"`` entries used to reach and authenticate
            against the server.
        """
        def connector(address_, **kwargs):
            # Called by the pool whenever it needs a new connection.
            return connect(address_, auth=cx_data["auth"], **kwargs)

        address = (cx_data["host"], cx_data["port"])
        self.pool = ConnectionPool(connector, address)

    def close(self):
        """Close the connection pool created by :meth:`open`."""
        self.pool.close()

    def _run_1(self, statement, parameters, graph, keys, entities):
        """Run a statement outside of any explicit transaction.

        A connection is acquired for this statement only. It goes back
        to the pool when the returned result signals completion, or
        straight away if submitting the statement raises.

        :param statement: statement text passed to ``cx.run``.
        :param parameters: statement parameters; dehydrated before
            sending, with an empty dict sent when the dehydrated value
            is falsy.
        :param graph: forwarded to the hydrator.
        :param keys: forwarded to the hydrator.
        :param entities: forwarded to the hydrator.
        :return: the ``CypherResult`` that receives metadata and records.
        :raises: whatever ``_fail`` raises when the server reports a
            failure, plus any error raised by the connection itself.
        """
        cx = self.pool.acquire()
        state = {"released": False}

        def release():
            # Release at most once, so the error path below can never
            # hand back a connection that the result already returned
            # to the pool (and that may have been re-acquired since).
            if not state["released"]:
                state["released"] = True
                self.pool.release(cx)

        try:
            hydrator = PackStreamHydrator(version=cx.protocol_version,
                                          graph=graph,
                                          keys=keys,
                                          entities=entities)
            dehydrated_parameters = hydrator.dehydrate(parameters)
            result = CypherResult(on_more=cx.fetch, on_done=release)
            result.update_metadata({"connection": self.connection_data})

            def update_metadata_with_keys(metadata):
                # The hydrator is given the keys once the RUN response
                # has been folded into the result metadata.
                result.update_metadata(metadata)
                hydrator.keys = result.keys()

            cx.run(statement,
                   dehydrated_parameters or {},
                   on_success=update_metadata_with_keys,
                   on_failure=self._fail)
            cx.pull_all(on_records=lambda records: result.append_records(
                map(hydrator.hydrate, records)),
                        on_success=result.update_metadata,
                        on_failure=self._fail,
                        on_summary=result.done)
            cx.send()
            cx.fetch()
        except Exception:
            # Without this the connection would stay marked as in use
            # forever, because the result that normally releases it is
            # never returned to the caller.
            release()
            raise
        return result

    def _run_in_tx(self, statement, parameters, tx, graph, keys, entities):
        """Run a statement inside the explicit transaction ``tx``.

        :param statement: statement text passed to ``tx.run``.
        :param parameters: statement parameters; dehydrated before
            sending, with an empty dict sent when the dehydrated value
            is falsy.
        :param tx: transaction connection previously returned by
            :meth:`begin`; validated with ``_assert_valid_tx``.
        :param graph: forwarded to the hydrator.
        :param keys: forwarded to the hydrator.
        :param entities: forwarded to the hydrator.
        :return: the ``CypherResult`` that receives metadata and records.
        :raises: whatever ``_fail`` raises when the server reports a
            failure; in that case the transaction is forgotten and its
            connection released first.
        """
        self._assert_valid_tx(tx)

        def fetch():
            tx.fetch()

        def fail(metadata):
            # A failed statement ends the transaction: forget it and
            # give the connection back before raising.
            self.transactions.remove(tx)
            self.pool.release(tx)
            self._fail(metadata)

        hydrator = PackStreamHydrator(version=tx.protocol_version,
                                      graph=graph,
                                      keys=keys,
                                      entities=entities)
        dehydrated_parameters = hydrator.dehydrate(parameters)
        result = CypherResult(on_more=fetch)
        result.update_metadata({"connection": self.connection_data})

        def update_metadata_with_keys(metadata):
            result.update_metadata(metadata)
            hydrator.keys = result.keys()

        tx.run(statement,
               dehydrated_parameters or {},
               on_success=update_metadata_with_keys,
               on_failure=fail)
        tx.pull_all(on_records=lambda records: result.append_records(
            map(hydrator.hydrate, records)),
                    on_success=result.update_metadata,
                    on_failure=fail,
                    on_summary=result.done)
        tx.send()
        result.keys()  # force receipt of RUN summary, to detect any errors
        return result

    @classmethod
    def _fail(cls, metadata):
        """Raise the ``GraphError`` hydrated from failure ``metadata``."""
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from py2neo.database import GraphError
        raise GraphError.hydrate(metadata)

    def run(self,
            statement,
            parameters=None,
            tx=None,
            graph=None,
            keys=None,
            entities=None):
        """Run a statement, inside ``tx`` when one is given.

        :param statement: statement text to execute.
        :param parameters: optional statement parameters.
        :param tx: transaction returned by :meth:`begin`, or ``None`` to
            run the statement on a connection of its own.
        :param graph: forwarded to the hydrator.
        :param keys: forwarded to the hydrator.
        :param entities: forwarded to the hydrator.
        :return: the ``CypherResult`` for the statement.
        """
        if tx is None:
            return self._run_1(statement, parameters, graph, keys, entities)
        else:
            return self._run_in_tx(statement, parameters, tx, graph, keys,
                                   entities)

    def begin(self):
        """Begin a transaction and return the connection that carries it.

        The connection stays out of the pool until :meth:`commit`,
        :meth:`rollback` or a statement failure releases it.
        """
        tx = self.pool.acquire()
        try:
            tx.begin()
        except Exception:
            # Do not strand the connection if the transaction could not
            # even be started.
            self.pool.release(tx)
            raise
        self.transactions.add(tx)
        return tx

    def commit(self, tx):
        """Commit ``tx`` and return its connection to the pool.

        The connection is released even when the commit or the
        following sync raises; the error still propagates.
        """
        self._assert_valid_tx(tx)
        self.transactions.remove(tx)
        try:
            tx.commit()
            tx.sync()
        finally:
            self.pool.release(tx)

    def rollback(self, tx):
        """Roll back ``tx`` and return its connection to the pool.

        The connection is released even when the rollback or the
        following sync raises; the error still propagates.
        """
        self._assert_valid_tx(tx)
        self.transactions.remove(tx)
        try:
            tx.rollback()
            tx.sync()
        finally:
            self.pool.release(tx)

    def sync(self, cx):
        """Call ``sync`` on the given connection."""
        cx.sync()
Esempio n. 2
0
class ConnectionPoolTestCase(TestCase):
    """Tests for ``ConnectionPool`` acquire/release bookkeeping.

    NOTE(review): most tests open connections to 127.0.0.1 through the
    module-level ``connector`` helper, so they presumably need a server
    listening there -- confirm against the test harness.
    """

    def setUp(self):
        self.pool = ConnectionPool(connector, ("127.0.0.1", 7687))

    def tearDown(self):
        self.pool.close()

    def assert_pool_size(self,
                         address,
                         expected_active,
                         expected_inactive,
                         pool=None):
        """Assert how many connections for ``address`` are in use and idle.

        :param address: key into ``pool.connections``.
        :param expected_active: expected number of connections whose
            ``in_use`` flag is set.
        :param expected_inactive: expected number of connections whose
            ``in_use`` flag is not set.
        :param pool: pool to inspect; defaults to ``self.pool``.
        """
        if pool is None:
            pool = self.pool
        try:
            connections = pool.connections[address]
        except KeyError:
            # An address the pool has never seen has no connections.
            self.assertEqual(0, expected_active)
            self.assertEqual(0, expected_inactive)
        else:
            self.assertEqual(expected_active,
                             len([cx for cx in connections if cx.in_use]))
            self.assertEqual(expected_inactive,
                             len([cx for cx in connections if not cx.in_use]))

    def test_can_acquire(self):
        address = ("127.0.0.1", 7687)
        connection = self.pool.acquire_direct(address)
        self.assertEqual(connection.address, address)
        self.assert_pool_size(address, 1, 0)

    def test_can_acquire_twice(self):
        address = ("127.0.0.1", 7687)
        connection_1 = self.pool.acquire_direct(address)
        connection_2 = self.pool.acquire_direct(address)
        self.assertEqual(connection_1.address, address)
        self.assertEqual(connection_2.address, address)
        self.assertIsNot(connection_1, connection_2)
        self.assert_pool_size(address, 2, 0)

    def test_can_acquire_two_addresses(self):
        address_1 = ("127.0.0.1", 7687)
        address_2 = ("127.0.0.1", 7474)
        connection_1 = self.pool.acquire_direct(address_1)
        connection_2 = self.pool.acquire_direct(address_2)
        self.assertEqual(connection_1.address, address_1)
        self.assertEqual(connection_2.address, address_2)
        self.assert_pool_size(address_1, 1, 0)
        self.assert_pool_size(address_2, 1, 0)

    def test_can_acquire_and_release(self):
        address = ("127.0.0.1", 7687)
        connection = self.pool.acquire_direct(address)
        self.assert_pool_size(address, 1, 0)
        self.pool.release(connection)
        self.assert_pool_size(address, 0, 1)

    def test_releasing_twice(self):
        # Releasing an already released connection must not change the
        # pool's bookkeeping.
        address = ("127.0.0.1", 7687)
        connection = self.pool.acquire_direct(address)
        self.pool.release(connection)
        self.assert_pool_size(address, 0, 1)
        self.pool.release(connection)
        self.assert_pool_size(address, 0, 1)

    def test_cannot_acquire_after_close(self):
        with ConnectionPool(lambda a: QuickConnection(FakeSocket(a)),
                            ()) as pool:
            pool.close()
            with self.assertRaises(ServiceUnavailable):
                pool.acquire_direct("X")

    def test_in_use_count(self):
        address = ("127.0.0.1", 7687)
        self.assertEqual(self.pool.in_use_connection_count(address), 0)
        connection = self.pool.acquire_direct(address)
        self.assertEqual(self.pool.in_use_connection_count(address), 1)
        self.pool.release(connection)
        self.assertEqual(self.pool.in_use_connection_count(address), 0)

    def test_max_conn_pool_size(self):
        with ConnectionPool(connector, (),
                            max_connection_pool_size=1,
                            connection_acquisition_timeout=0) as pool:
            address = ("127.0.0.1", 7687)
            pool.acquire_direct(address)
            self.assertEqual(pool.in_use_connection_count(address), 1)
            # The pool is full and the acquisition timeout is zero, so a
            # second acquisition must fail without changing the count.
            with self.assertRaises(ClientError):
                pool.acquire_direct(address)
            self.assertEqual(pool.in_use_connection_count(address), 1)

    def test_multithread(self):
        import time  # only this test needs it

        with ConnectionPool(connector, (),
                            max_connection_pool_size=5,
                            connection_acquisition_timeout=10) as pool:
            address = ("127.0.0.1", 7687)
            releasing_event = Event()

            # We start 10 threads to compete connections from pool with size of 5
            threads = []
            try:
                for _ in range(10):
                    t = Thread(target=acquire_release_conn,
                               args=(pool, address, releasing_event))
                    t.start()
                    threads.append(t)

                # Starting a thread does not mean it has acquired its
                # connection yet, so give the workers a bounded amount of
                # time to fill the pool before inspecting it. The bound is
                # kept below the 10s acquisition timeout configured above.
                deadline = time.time() + 5
                while (pool.in_use_connection_count(address) < 5
                       and time.time() < deadline):
                    time.sleep(0.01)

                # The pool size should be 5, all are in-use
                self.assert_pool_size(address, 5, 0, pool)
            finally:
                # Always let the workers go and wait for them, even when
                # the assertion above fails.
                # NOTE(review): presumably each worker blocks on this event
                # before releasing its connection -- confirm against
                # acquire_release_conn.
                releasing_event.set()
                for t in threads:
                    t.join()
            # The pool size is still 5, but all are free
            self.assert_pool_size(address, 0, 5, pool)