Beispiel #1
0
    def setUp(self):
        """Build a ``SqlService`` on mocked collaborators and start one query.

        Every invocation handed to the mocked invocation service is stored in
        ``self.invocation_registry`` under a sequential integer key instead of
        being sent anywhere, so tests can complete its future later.
        """
        self.connection = MagicMock()

        manager = MagicMock(client_uuid=uuid.uuid4())
        manager.get_random_connection_for_sql = MagicMock(
            return_value=self.connection)

        def passthrough(value):
            # Identity conversion: these tests do no real (de)serialization.
            return value

        serializer = MagicMock()
        serializer.to_object.side_effect = passthrough
        serializer.to_data.side_effect = passthrough

        self.invocation_registry = {}
        next_id = itertools.count()

        def record(invocation):
            # Looked up on ``self`` at call time, so the current registry is
            # always the one written to.
            self.invocation_registry[next(next_id)] = invocation

        invoker = MagicMock()
        invoker.invoke.side_effect = record

        self.internal_service = _InternalSqlService(
            manager, serializer, invoker)
        self.service = SqlService(self.internal_service)
        self.result = self.service.execute("SOME QUERY")
 def __init__(self, **kwargs):
     """Create the client, wire its internal services together and start it.

     Keyword arguments are turned into the client configuration via
     ``_Config.from_dict(kwargs)``; the set of accepted keys is defined by
     ``_Config`` (not visible here).

     The constructor builds the reactor, serialization, near cache,
     lifecycle, invocation, partition, cluster, connection, listener, proxy,
     CP subsystem, transaction, statistics, cluster-view and SQL services,
     then calls ``self._init_context()`` and ``self._start()``. Statement
     order matters: several services receive previously built ones as
     constructor arguments.
     """
     config = _Config.from_dict(kwargs)
     self._config = config
     self._context = _ClientContext()
     # Take the next value of the class-level ``_CLIENT_ID`` counter and
     # derive this client's name from it.
     client_id = HazelcastClient._CLIENT_ID.get_and_increment()
     self._name = self._create_client_name(client_id)
     self._reactor = AsyncoreReactor()
     self._serialization_service = SerializationServiceV1(config)
     self._near_cache_manager = NearCacheManager(
         config, self._serialization_service)
     self._internal_lifecycle_service = _InternalLifecycleService(config)
     self._lifecycle_service = LifecycleService(
         self._internal_lifecycle_service)
     # Phase one of the invocation service: created before the partition
     # service, connection manager and listener service exist. Those are
     # handed to it later via ``self._invocation_service.init(...)`` below.
     self._invocation_service = InvocationService(self, config,
                                                  self._reactor)
     self._address_provider = self._create_address_provider()
     self._internal_partition_service = _InternalPartitionService(self)
     self._partition_service = PartitionService(
         self._internal_partition_service, self._serialization_service)
     self._internal_cluster_service = _InternalClusterService(self, config)
     self._cluster_service = ClusterService(self._internal_cluster_service)
     # The connection manager depends on most of the services created above,
     # so it has to be constructed after them.
     self._connection_manager = ConnectionManager(
         self,
         config,
         self._reactor,
         self._address_provider,
         self._internal_lifecycle_service,
         self._internal_partition_service,
         self._internal_cluster_service,
         self._invocation_service,
         self._near_cache_manager,
     )
     self._load_balancer = self._init_load_balancer(config)
     self._listener_service = ListenerService(self, config,
                                              self._connection_manager,
                                              self._invocation_service)
     # The following services are built from the shared client context only.
     self._proxy_manager = ProxyManager(self._context)
     self._cp_subsystem = CPSubsystem(self._context)
     self._proxy_session_manager = ProxySessionManager(self._context)
     self._transaction_manager = TransactionManager(self._context)
     # Counter starting at 1; what the generated lock reference ids are used
     # for is not visible here.
     self._lock_reference_id_generator = AtomicInteger(1)
     self._statistics = Statistics(
         self,
         config,
         self._reactor,
         self._connection_manager,
         self._invocation_service,
         self._near_cache_manager,
     )
     self._cluster_view_listener = ClusterViewListenerService(
         self,
         self._connection_manager,
         self._internal_partition_service,
         self._internal_cluster_service,
         self._invocation_service,
     )
     # Reentrant lock; presumably guards shutdown (its users are not visible
     # here) -- TODO confirm.
     self._shutdown_lock = threading.RLock()
     # Phase two of the invocation service. ConnectionManager and
     # ListenerService were constructed with the invocation service itself,
     # so they can only be supplied to it now, after they exist.
     self._invocation_service.init(self._internal_partition_service,
                                   self._connection_manager,
                                   self._listener_service)
     self._internal_sql_service = _InternalSqlService(
         self._connection_manager, self._serialization_service,
         self._invocation_service)
     # Public SQL entry point exposed as ``client.sql``.
     self.sql = SqlService(self._internal_sql_service)
     # Populate the shared context, then start the client; both methods are
     # defined outside this view.
     self._init_context()
     self._start()
Beispiel #3
0
class SqlMockTest(unittest.TestCase):
    """Unit tests for ``SqlService`` driven entirely by mocked collaborators.

    The connection manager, serialization service and invocation service are
    ``MagicMock`` objects. Every invocation the service sends is captured in
    ``self.invocation_registry`` so a test can later complete its future with
    a canned response or an exception, selected by request message type.
    """

    def setUp(self):
        self.connection = MagicMock()

        connection_manager = MagicMock(client_uuid=uuid.uuid4())
        connection_manager.get_random_connection_for_sql = MagicMock(
            return_value=self.connection)

        # Identity (de)serialization: values pass through unchanged.
        serialization_service = MagicMock()
        serialization_service.to_object.side_effect = lambda arg: arg
        serialization_service.to_data.side_effect = lambda arg: arg

        # Capture invocations instead of sending them; keys are sequential
        # fake correlation ids.
        self.invocation_registry = {}
        correlation_id_counter = itertools.count()
        invocation_service = MagicMock()

        def invoke(invocation):
            self.invocation_registry[next(correlation_id_counter)] = invocation

        invocation_service.invoke.side_effect = invoke

        self.internal_service = _InternalSqlService(connection_manager,
                                                    serialization_service,
                                                    invocation_service)
        self.service = SqlService(self.internal_service)
        # The execute request is captured here; tests resolve it through the
        # set_execute_* helpers.
        self.result = self.service.execute("SOME QUERY")

    def test_iterator_with_rows(self):
        self.set_execute_response_with_rows()
        result = self.result.result()

        self.assertEqual(-1, result.update_count())
        self.assertTrue(result.is_row_set())
        self.assertIsInstance(result.get_row_metadata(), SqlRowMetadata)
        self.assertEqual(EXPECTED_ROWS, self.get_rows_from_iterator(result))

    def test_blocking_iterator_with_rows(self):
        self.set_execute_response_with_rows()
        result = self.result.result()

        self.assertEqual(-1, result.update_count())
        self.assertTrue(result.is_row_set())
        self.assertIsInstance(result.get_row_metadata(), SqlRowMetadata)
        self.assertEqual(EXPECTED_ROWS,
                         self.get_rows_from_blocking_iterator(result))

    def test_iterator_with_update_count(self):
        self.set_execute_response_with_update_count()
        result = self.result.result()

        self.assertEqual(EXPECTED_UPDATE_COUNT, result.update_count())
        self.assertFalse(result.is_row_set())

        with self.assertRaises(ValueError):
            result.get_row_metadata()

        with self.assertRaises(ValueError):
            result.iterator()

    def test_blocking_iterator_with_update_count(self):
        self.set_execute_response_with_update_count()
        result = self.result.result()

        self.assertEqual(EXPECTED_UPDATE_COUNT, result.update_count())
        self.assertFalse(result.is_row_set())

        with self.assertRaises(ValueError):
            result.get_row_metadata()

        with self.assertRaises(ValueError):
            for _ in result:
                pass

    def test_execute_error(self):
        self.set_execute_error(RuntimeError("expected"))

        with self.assertRaises(HazelcastSqlError) as cm:
            result = self.result.result()
            # Only reached if result() itself does not raise.
            iter(result)

        self.assertEqual(_SqlErrorCode.GENERIC, cm.exception._code)

    def test_execute_error_when_connection_is_not_live(self):
        self.connection.live = False
        self.set_execute_error(RuntimeError("expected"))

        with self.assertRaises(HazelcastSqlError) as cm:
            result = self.result.result()
            # Only reached if result() itself does not raise.
            iter(result)

        self.assertEqual(_SqlErrorCode.CONNECTION_PROBLEM, cm.exception._code)

    def test_close_when_close_request_fails(self):
        self.set_execute_response_with_rows(is_last=False)
        result = self.result.result()

        future = result.close()
        self.set_close_error(
            HazelcastSqlError(None, _SqlErrorCode.PARSING, "expected", None))

        with self.assertRaises(HazelcastSqlError) as cm:
            future.result()

        self.assertEqual(_SqlErrorCode.PARSING, cm.exception._code)

    def test_fetch_error(self):
        self.set_execute_response_with_rows(is_last=False)
        result = self.result.result()

        future = self._read_first_page_and_request_next(result)

        self.set_fetch_error(RuntimeError("expected"))

        with self.assertRaises(HazelcastSqlError) as cm:
            future.result()

        self.assertEqual(_SqlErrorCode.GENERIC, cm.exception._code)

    def test_fetch_server_error(self):
        self.set_execute_response_with_rows(is_last=False)
        result = self.result.result()

        future = self._read_first_page_and_request_next(result)

        self.set_fetch_response_with_error()

        with self.assertRaises(HazelcastSqlError) as cm:
            future.result()

        self.assertEqual(_SqlErrorCode.PARSING, cm.exception._code)

    def test_close_in_between_fetches(self):
        self.set_execute_response_with_rows(is_last=False)
        result = self.result.result()

        future = self._read_first_page_and_request_next(result)

        result.close()

        with self.assertRaises(HazelcastSqlError) as cm:
            future.result()

        self.assertEqual(_SqlErrorCode.CANCELLED_BY_USER, cm.exception._code)

    def _read_first_page_and_request_next(self, result):
        """Consume the first page of ``result`` and return the next row future.

        Asserts that the first page yields ``EXPECTED_ROWS``. When the page
        was not marked as last, requesting the following row initiates the
        fetch request; the returned future belongs to that row.
        """
        i = result.iterator()
        # First page contains two rows
        rows = [next(i).result().get_object_with_index(0) for _ in range(2)]

        self.assertEqual(EXPECTED_ROWS, rows)

        # initiate the fetch request
        return next(i)

    def set_fetch_response_with_error(self):
        """Resolve fetch requests with a PARSING server error and no page."""
        response = {
            "row_page": None,
            "error": _SqlError(_SqlErrorCode.PARSING, "expected", None, None,
                               ""),
        }
        self.set_future_result_or_exception(
            response, sql_fetch_codec._REQUEST_MESSAGE_TYPE)

    def set_fetch_error(self, error):
        """Fail fetch requests with ``error``."""
        self.set_future_result_or_exception(
            error, sql_fetch_codec._REQUEST_MESSAGE_TYPE)

    def set_close_error(self, error):
        """Fail close requests with ``error``."""
        self.set_future_result_or_exception(
            error, sql_close_codec._REQUEST_MESSAGE_TYPE)

    def set_close_response(self):
        """Complete close requests successfully with a ``None`` result."""
        self.set_future_result_or_exception(
            None, sql_close_codec._REQUEST_MESSAGE_TYPE)

    def set_execute_response_with_update_count(self):
        """Resolve the execute request as an update-count (non-row) result."""
        self.set_execute_response(EXPECTED_UPDATE_COUNT, None, None, None)

    @staticmethod
    def get_rows_from_blocking_iterator(result):
        """Return column 0 of every row, using the blocking iteration protocol."""
        return [row.get_object_with_index(0) for row in result]

    @staticmethod
    def get_rows_from_iterator(result):
        """Return column 0 of every row, using the future-based iterator.

        Iteration ends when resolving a row future raises ``StopIteration``.
        """
        rows = []
        for row_future in result.iterator():
            try:
                row = row_future.result()
                rows.append(row.get_object_with_index(0))
            except StopIteration:
                break
        return rows

    def set_execute_response_with_rows(self, is_last=True):
        """Resolve the execute request with one VARCHAR page of EXPECTED_ROWS.

        :param is_last: Whether that page is flagged as the final one.
        """
        self.set_execute_response(
            -1,
            [SqlColumnMetadata("name", SqlColumnType.VARCHAR, True, True)],
            _SqlPage([SqlColumnType.VARCHAR], [EXPECTED_ROWS], is_last),
            None,
        )

    def set_execute_response(self, update_count, row_metadata, row_page,
                             error):
        """Resolve execute requests with a response dict built from the args."""
        response = {
            "update_count": update_count,
            "row_metadata": row_metadata,
            "row_page": row_page,
            "error": error,
        }

        self.set_future_result_or_exception(
            response, sql_execute_codec._REQUEST_MESSAGE_TYPE)

    def set_execute_error(self, error):
        """Fail execute requests with ``error``."""
        self.set_future_result_or_exception(
            error, sql_execute_codec._REQUEST_MESSAGE_TYPE)

    def get_message_type(self, invocation):
        """Decode the message type stored in an invocation's request buffer.

        Reads the ``LE_INT`` value at ``_OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET``.
        """
        return LE_INT.unpack_from(invocation.request.buf,
                                  _OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET)[0]

    def set_future_result_or_exception(self, value, message_type):
        """Complete every captured invocation of the given message type.

        :param value: An ``Exception`` instance is set as the future's
            exception; any other value is set as its result.
        :param message_type: Request message type to match against
            ``get_message_type(invocation)``.
        :raises AssertionError: (via ``self.fail``) if no captured invocation
            has that message type.
        """
        matched = False
        # Iterate over a snapshot: completing a future may run callbacks that
        # send new requests, which would otherwise insert into the registry
        # while its live view is being iterated (RuntimeError).
        for invocation in list(self.invocation_registry.values()):
            if self.get_message_type(invocation) != message_type:
                continue
            matched = True
            if isinstance(value, Exception):
                invocation.future.set_exception(value)
            else:
                invocation.future.set_result(value)

        # Fail loudly here rather than letting the test break later (or wait)
        # on a future that was never completed.
        if not matched:
            self.fail("No captured invocation with message type %s"
                      % message_type)