def setUp(self):
    """Build a SqlService backed by mocks and start one query.

    Invocations handed to the mocked invocation service are not sent
    anywhere; they are stored in ``self.invocation_registry`` so a test
    can complete their futures by hand afterwards.
    """

    def passthrough(arg):
        # Serialization is a no-op here: values go in and out untouched.
        return arg

    def remember(invocation):
        # File every invocation under the next sequential key.
        self.invocation_registry[next(sequence)] = invocation

    self.connection = MagicMock()

    conn_manager = MagicMock(client_uuid=uuid.uuid4())
    conn_manager.get_random_connection_for_sql = MagicMock(
        return_value=self.connection
    )

    serializer = MagicMock()
    serializer.to_object.side_effect = passthrough
    serializer.to_data.side_effect = passthrough

    self.invocation_registry = {}
    sequence = itertools.count()

    invoker = MagicMock()
    invoker.invoke.side_effect = remember

    self.internal_service = _InternalSqlService(conn_manager, serializer, invoker)
    self.service = SqlService(self.internal_service)
    self.result = self.service.execute("SOME QUERY")
def __init__(self, **kwargs):
    """Create the client and wire together all of its internal services.

    Args:
        **kwargs: Configuration options; they are turned into a config
            object with ``_Config.from_dict``.

    The statements below are order-sensitive: several collaborators
    receive ``self`` or previously created services as constructor
    arguments, so do not reorder them without checking what each
    constructor reads.
    """
    config = _Config.from_dict(kwargs)
    self._config = config
    self._context = _ClientContext()

    # The client name is derived from a class-level counter value.
    client_id = HazelcastClient._CLIENT_ID.get_and_increment()
    self._name = self._create_client_name(client_id)

    # Core building blocks that the later services are constructed from.
    self._reactor = AsyncoreReactor()
    self._serialization_service = SerializationServiceV1(config)
    self._near_cache_manager = NearCacheManager(
        config, self._serialization_service)

    # Each "_internal_*" service is wrapped by a second object stored
    # without the prefix. NOTE(review): presumably the wrapper is the
    # user-facing API -- confirm against the wrapper classes.
    self._internal_lifecycle_service = _InternalLifecycleService(config)
    self._lifecycle_service = LifecycleService(
        self._internal_lifecycle_service)
    self._invocation_service = InvocationService(self, config, self._reactor)
    self._address_provider = self._create_address_provider()
    self._internal_partition_service = _InternalPartitionService(self)
    self._partition_service = PartitionService(
        self._internal_partition_service, self._serialization_service)
    self._internal_cluster_service = _InternalClusterService(self, config)
    self._cluster_service = ClusterService(self._internal_cluster_service)

    # The connection manager is built from the services created above.
    self._connection_manager = ConnectionManager(
        self,
        config,
        self._reactor,
        self._address_provider,
        self._internal_lifecycle_service,
        self._internal_partition_service,
        self._internal_cluster_service,
        self._invocation_service,
        self._near_cache_manager,
    )
    self._load_balancer = self._init_load_balancer(config)
    self._listener_service = ListenerService(self, config, self._connection_manager, self._invocation_service)

    # These managers only receive the shared client context, which is
    # still empty at this point; self._init_context() is called further
    # down, after every service exists.
    self._proxy_manager = ProxyManager(self._context)
    self._cp_subsystem = CPSubsystem(self._context)
    self._proxy_session_manager = ProxySessionManager(self._context)
    self._transaction_manager = TransactionManager(self._context)
    self._lock_reference_id_generator = AtomicInteger(1)
    self._statistics = Statistics(
        self,
        config,
        self._reactor,
        self._connection_manager,
        self._invocation_service,
        self._near_cache_manager,
    )
    self._cluster_view_listener = ClusterViewListenerService(
        self,
        self._connection_manager,
        self._internal_partition_service,
        self._internal_cluster_service,
        self._invocation_service,
    )
    self._shutdown_lock = threading.RLock()

    # The invocation service was constructed before the connection manager
    # and the listener service existed, so they are handed over here in a
    # separate init() call.
    self._invocation_service.init(self._internal_partition_service, self._connection_manager, self._listener_service)

    # "sql" is the only attribute assigned in this constructor without a
    # leading underscore.
    self._internal_sql_service = _InternalSqlService(
        self._connection_manager, self._serialization_service, self._invocation_service)
    self.sql = SqlService(self._internal_sql_service)

    # Both helpers are defined outside this view. NOTE(review): presumably
    # _init_context() fills self._context and _start() brings the client
    # up -- confirm in their definitions.
    self._init_context()
    self._start()
class SqlMockTest(unittest.TestCase):
    """SQL service tests that run against mocked collaborators only.

    ``setUp`` starts a query through a ``SqlService`` whose invocation
    service merely records the invocations it is given. The ``set_*``
    helpers then complete the recorded futures with a canned response or
    an error, which lets each test steer the execute/fetch/close steps.
    """

    def setUp(self):
        def passthrough(arg):
            # Serialization is a no-op here: values go in and out untouched.
            return arg

        def remember(invocation):
            # File every invocation under the next sequential key.
            self.invocation_registry[next(sequence)] = invocation

        self.connection = MagicMock()

        conn_manager = MagicMock(client_uuid=uuid.uuid4())
        conn_manager.get_random_connection_for_sql = MagicMock(
            return_value=self.connection
        )

        serializer = MagicMock()
        serializer.to_object.side_effect = passthrough
        serializer.to_data.side_effect = passthrough

        self.invocation_registry = {}
        sequence = itertools.count()

        invoker = MagicMock()
        invoker.invoke.side_effect = remember

        self.internal_service = _InternalSqlService(conn_manager, serializer, invoker)
        self.service = SqlService(self.internal_service)
        self.result = self.service.execute("SOME QUERY")

    # -- execute: row results ---------------------------------------------

    def test_iterator_with_rows(self):
        self.set_execute_response_with_rows()
        sql_result = self.result.result()

        self.assertEqual(-1, sql_result.update_count())
        self.assertTrue(sql_result.is_row_set())
        self.assertIsInstance(sql_result.get_row_metadata(), SqlRowMetadata)

        collected = self.get_rows_from_iterator(sql_result)
        self.assertEqual(EXPECTED_ROWS, collected)

    def test_blocking_iterator_with_rows(self):
        self.set_execute_response_with_rows()
        sql_result = self.result.result()

        self.assertEqual(-1, sql_result.update_count())
        self.assertTrue(sql_result.is_row_set())
        self.assertIsInstance(sql_result.get_row_metadata(), SqlRowMetadata)

        collected = self.get_rows_from_blocking_iterator(sql_result)
        self.assertEqual(EXPECTED_ROWS, collected)

    # -- execute: update-count results ------------------------------------

    def test_iterator_with_update_count(self):
        self.set_execute_response_with_update_count()
        sql_result = self.result.result()

        self.assertEqual(EXPECTED_UPDATE_COUNT, sql_result.update_count())
        self.assertFalse(sql_result.is_row_set())

        # Row-oriented accessors must be rejected for an update-count result.
        self.assertRaises(ValueError, sql_result.get_row_metadata)
        self.assertRaises(ValueError, sql_result.iterator)

    def test_blocking_iterator_with_update_count(self):
        self.set_execute_response_with_update_count()
        sql_result = self.result.result()

        self.assertEqual(EXPECTED_UPDATE_COUNT, sql_result.update_count())
        self.assertFalse(sql_result.is_row_set())

        self.assertRaises(ValueError, sql_result.get_row_metadata)

        with self.assertRaises(ValueError):
            for _row in sql_result:
                pass

    # -- execute: failures ------------------------------------------------

    def test_execute_error(self):
        self.set_execute_error(RuntimeError("expected"))

        with self.assertRaises(HazelcastSqlError) as raised:
            sql_result = self.result.result()
            iter(sql_result)

        self.assertEqual(_SqlErrorCode.GENERIC, raised.exception._code)

    def test_execute_error_when_connection_is_not_live(self):
        self.connection.live = False
        self.set_execute_error(RuntimeError("expected"))

        with self.assertRaises(HazelcastSqlError) as raised:
            sql_result = self.result.result()
            iter(sql_result)

        self.assertEqual(_SqlErrorCode.CONNECTION_PROBLEM, raised.exception._code)

    # -- close / fetch ----------------------------------------------------

    def test_close_when_close_request_fails(self):
        self.set_execute_response_with_rows(is_last=False)
        sql_result = self.result.result()

        close_future = sql_result.close()
        close_failure = HazelcastSqlError(None, _SqlErrorCode.PARSING, "expected", None)
        self.set_close_error(close_failure)

        with self.assertRaises(HazelcastSqlError) as raised:
            close_future.result()

        self.assertEqual(_SqlErrorCode.PARSING, raised.exception._code)

    def test_fetch_error(self):
        self.set_execute_response_with_rows(is_last=False)
        sql_result = self.result.result()
        row_futures = sql_result.iterator()

        # The first page holds exactly two rows.
        first_page = [
            next(row_futures).result().get_object_with_index(0) for _ in range(2)
        ]
        self.assertEqual(EXPECTED_ROWS, first_page)

        # Asking for a third row initiates the fetch request.
        pending = next(row_futures)
        self.set_fetch_error(RuntimeError("expected"))

        with self.assertRaises(HazelcastSqlError) as raised:
            pending.result()

        self.assertEqual(_SqlErrorCode.GENERIC, raised.exception._code)

    def test_fetch_server_error(self):
        self.set_execute_response_with_rows(is_last=False)
        sql_result = self.result.result()
        row_futures = sql_result.iterator()

        # The first page holds exactly two rows.
        first_page = [
            next(row_futures).result().get_object_with_index(0) for _ in range(2)
        ]
        self.assertEqual(EXPECTED_ROWS, first_page)

        # Asking for a third row initiates the fetch request.
        pending = next(row_futures)
        self.set_fetch_response_with_error()

        with self.assertRaises(HazelcastSqlError) as raised:
            pending.result()

        self.assertEqual(_SqlErrorCode.PARSING, raised.exception._code)

    def test_close_in_between_fetches(self):
        self.set_execute_response_with_rows(is_last=False)
        sql_result = self.result.result()
        row_futures = sql_result.iterator()

        # The first page holds exactly two rows.
        first_page = [
            next(row_futures).result().get_object_with_index(0) for _ in range(2)
        ]
        self.assertEqual(EXPECTED_ROWS, first_page)

        # Asking for a third row initiates the fetch request; the result is
        # closed while that request is still outstanding.
        pending = next(row_futures)
        sql_result.close()

        with self.assertRaises(HazelcastSqlError) as raised:
            pending.result()

        self.assertEqual(_SqlErrorCode.CANCELLED_BY_USER, raised.exception._code)

    # -- helpers: canned responses ----------------------------------------

    def set_fetch_response_with_error(self):
        """Complete recorded fetch invocations with a response carrying a PARSING error."""
        server_error = _SqlError(_SqlErrorCode.PARSING, "expected", None, None, "")
        payload = dict(row_page=None, error=server_error)
        self.set_future_result_or_exception(
            payload, sql_fetch_codec._REQUEST_MESSAGE_TYPE
        )

    def set_fetch_error(self, error):
        """Fail recorded fetch invocations with ``error``."""
        fetch_type = sql_fetch_codec._REQUEST_MESSAGE_TYPE
        self.set_future_result_or_exception(error, fetch_type)

    def set_close_error(self, error):
        """Fail recorded close invocations with ``error``."""
        close_type = sql_close_codec._REQUEST_MESSAGE_TYPE
        self.set_future_result_or_exception(error, close_type)

    def set_close_response(self):
        """Complete recorded close invocations with a ``None`` result."""
        close_type = sql_close_codec._REQUEST_MESSAGE_TYPE
        self.set_future_result_or_exception(None, close_type)

    def set_execute_response_with_update_count(self):
        """Answer the execute request with an update count and no row data."""
        self.set_execute_response(EXPECTED_UPDATE_COUNT, None, None, None)

    @staticmethod
    def get_rows_from_blocking_iterator(result):
        """Iterate ``result`` directly and return column 0 of every row."""
        values = []
        for row in result:
            values.append(row.get_object_with_index(0))
        return values

    @staticmethod
    def get_rows_from_iterator(result):
        """Resolve each future from ``result.iterator()`` and return column 0 of every row.

        Collection stops at the first ``StopIteration`` raised while a row
        is being resolved or read.
        """
        values = []
        for pending_row in result.iterator():
            try:
                values.append(pending_row.result().get_object_with_index(0))
            except StopIteration:
                break
        return values

    def set_execute_response_with_rows(self, is_last=True):
        """Answer the execute request with a one-column page built from EXPECTED_ROWS."""
        metadata = [SqlColumnMetadata("name", SqlColumnType.VARCHAR, True, True)]
        page = _SqlPage([SqlColumnType.VARCHAR], [EXPECTED_ROWS], is_last)
        self.set_execute_response(-1, metadata, page, None)

    def set_execute_response(self, update_count, row_metadata, row_page, error):
        """Complete recorded execute invocations with a response built from the arguments."""
        payload = dict(
            update_count=update_count,
            row_metadata=row_metadata,
            row_page=row_page,
            error=error,
        )
        self.set_future_result_or_exception(
            payload, sql_execute_codec._REQUEST_MESSAGE_TYPE
        )

    def set_execute_error(self, error):
        """Fail recorded execute invocations with ``error``."""
        execute_type = sql_execute_codec._REQUEST_MESSAGE_TYPE
        self.set_future_result_or_exception(error, execute_type)

    # -- helpers: invocation plumbing -------------------------------------

    def get_message_type(self, invocation):
        """Read the message type stored in the invocation's request buffer."""
        unpacked = LE_INT.unpack_from(
            invocation.request.buf, _OUTBOUND_MESSAGE_MESSAGE_TYPE_OFFSET
        )
        return unpacked[0]

    def set_future_result_or_exception(self, value, message_type):
        """Complete the future of every recorded invocation of ``message_type``.

        An ``Exception`` instance is set as the future's exception; any
        other value is set as its result.
        """
        is_failure = isinstance(value, Exception)
        for invocation in self.invocation_registry.values():
            if self.get_message_type(invocation) != message_type:
                continue
            if is_failure:
                invocation.future.set_exception(value)
            else:
                invocation.future.set_result(value)