def test_init_session_failed_connection(self):
    """Initializing a session over a failed connection reports an error and drops the session."""
    # Setup:
    # ... A connection service whose connect() always reports the error 'Boom!'
    conn_service = ConnectionService()
    failed_response = ConnectionCompleteParams()
    failed_response.error_message = 'Boom!'
    conn_service.connect = mock.MagicMock(return_value=failed_response)

    # ... An OE service that resolves the connection service above
    oe_service = ObjectExplorerService()
    oe_service._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: conn_service})

    # If: a registered session is initialized directly
    # (the request handler is bypassed to avoid threading issues)
    params, session_uri = _connection_details()
    session = ObjectExplorerSession(session_uri, params)
    oe_service._session_map[session_uri] = session

    def check_error_notification(notification):
        # Delegate to the shared init-error validator for this session URI
        return self._validate_init_error(notification, session_uri)

    validator = RequestFlowValidator()
    validator.add_expected_notification(
        SessionCreatedParameters, SESSION_CREATED_METHOD, check_error_notification)
    oe_service._initialize_session(validator.request_context, session)

    # Then:
    # ... The expected error notification was sent and the session map is empty again
    validator.validate()
    self.assertDictEqual(oe_service._session_map, {})
def test_handle_create_session_successful(self):
    """Creating a session over a successful connection yields a ready session and a root node."""
    # Setup:
    # ... Create OE service with mock connection service that returns a successful connection response
    mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******', port=123)
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: cs})
    oe._provider = constants.PG_PROVIDER_NAME
    # NOTE(review): the Server class itself (not an instance) is assigned here; presumably the
    # service instantiates it per session — confirm against ObjectExplorerService.
    oe._server = Server

    # ... Create the connection parameters and the URI the session will be registered under
    params, session_uri = _connection_details()

    # ... Create validation of success notification
    def validate_success_notification(response: SessionCreatedParameters):
        self.assertTrue(response.success)
        self.assertEqual(response.session_id, session_uri)
        self.assertIsNone(response.error_message)

        root = response.root_node
        self.assertIsInstance(root, NodeInfo)
        self.assertEqual(root.label, TEST_DBNAME)
        self.assertEqual(root.node_path, session_uri)
        self.assertEqual(root.node_type, 'Database')
        self.assertIsInstance(root.metadata, ObjectMetadata)

        # The session is placed in the map by _handle_create_session_request below
        session_server = oe._session_map[session_uri].server
        self.assertEqual(root.metadata.urn, session_server.urn_base)
        self.assertEqual(root.metadata.name, session_server.maintenance_db_name)
        self.assertEqual(root.metadata.metadata_type_name, 'Database')
        self.assertFalse(root.is_leaf)

    rc = RequestFlowValidator()
    rc.add_expected_response(
        CreateSessionResponse,
        lambda param: self.assertEqual(param.session_id, session_uri))
    rc.add_expected_notification(
        SessionCreatedParameters, SESSION_CREATED_METHOD, validate_success_notification)

    # If: I create a session and wait for the session's init_task to complete
    oe._handle_create_session_request(rc.request_context, params)
    oe._session_map[session_uri].init_task.join()

    # Then:
    # ... The create-session response and the success notification should have been sent
    rc.validate()

    # ... The session should still exist and should have connection and server setup
    self.assertIn(session_uri, oe._session_map)
    self.assertIsInstance(oe._session_map[session_uri].server, Server)
    self.assertTrue(oe._session_map[session_uri].is_ready)
def test_handle_scriptas_successful_operation(self):
    """Every ScriptOperation is answered with the scripter's output, scripting once per operation."""
    # NOTE: There's no need to test all types here, the scripter tests should handle this

    # Setup:
    # ... A connection service that hands back a mock PG connection
    pg_connection = MockPGServerConnection()
    conn_service = ConnectionService()
    conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    conn_service.get_connection = mock.MagicMock(return_value=pg_connection)

    # ... A scripting service wired to that connection service
    scripting_service = ScriptingService()
    scripting_service._service_provider = utils.get_mock_service_provider(
        {CONNECTION_SERVICE_NAME: conn_service})

    def check_response(response: ScriptAsResponse) -> None:
        # Each response must echo the owner URI and carry the canned script
        self.assertEqual(response.owner_uri, TestScriptingService.MOCK_URI)
        self.assertEqual(response.script, TestScriptingService.MOCK_SCRIPT)

    # The single table object scripted for every operation
    table_object = {
        'type': 'Table',
        'name': 'test_table',
        'schema': 'test_schema'
    }

    # ... Replace the Scripter used by the scripting service with one whose script() is canned
    with mock.patch('ossdbtoolsservice.scripting.scripting_service.Scripter') as scripter_class:
        fake_scripter: Scripter = Scripter(pg_connection)
        fake_scripter.script = mock.MagicMock(return_value=TestScriptingService.MOCK_SCRIPT)
        scripter_class.return_value = fake_scripter

        for op in ScriptOperation:
            # If: a script-as request is made for this operation
            validator: RequestFlowValidator = RequestFlowValidator()
            validator.add_expected_response(ScriptAsResponse, check_response)
            request_params = ScriptAsParameters.from_dict({
                'ownerUri': TestScriptingService.MOCK_URI,
                'operation': op,
                'scripting_objects': [table_object]
            })
            scripting_service._handle_scriptas_request(validator.request_context, request_params)

            # Then: ... the request was answered with the expected response
            validator.validate()

        # ... and script() was invoked exactly once per operation (first positional arg)
        call_counts = dict.fromkeys(ScriptOperation, 0)
        for recorded_call in fake_scripter.script.call_args_list:
            call_counts[recorded_call[0][0]] += 1

        for count in call_counts.values():
            self.assertEqual(count, 1)
def test_create_connection_successful(self):
    """_create_connection returns the connection handed out by the connection service."""
    # Setup:
    # ... A connection service whose connect() succeeds and whose get_connection() returns a mock
    mock_connection = MockPGServerConnection()
    oe = ObjectExplorerService()
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: cs})
    params, session_uri = _connection_details()
    session = ObjectExplorerSession(session_uri, params)

    # If: I create a connection to a database for the session
    connection = oe._create_connection(session, 'foo_database')

    # Then:
    # ... The very object returned by get_connection comes back. An identity check is used
    #     rather than equality, which a mock could satisfy via __eq__. It also rules out None.
    self.assertIs(connection, mock_connection)
    # ... connect and get_connection were each called exactly once
    cs.connect.assert_called_once()
    cs.get_connection.assert_called_once()
def test_create_connection_failed(self):
    """_create_connection raises RuntimeError carrying the connection error message."""
    # Setup:
    # ... A connection service whose connect() reports an error message
    oe = ObjectExplorerService()
    cs = ConnectionService()
    connect_response = ConnectionCompleteParams()
    error = 'Failed'
    connect_response.error_message = error
    cs.connect = mock.MagicMock(return_value=connect_response)
    oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: cs})
    params, session_uri = _connection_details()
    session = ObjectExplorerSession(session_uri, params)

    # If: I try to create a connection for the session
    with self.assertRaises(RuntimeError) as context:
        oe._create_connection(session, 'foo_database')

    # Then:
    # ... The exception message is the connect error. This check must sit outside the
    #     assertRaises block; after the raising call inside it, it would never execute.
    self.assertEqual(error, str(context.exception))
    # ... connect was attempted exactly once
    cs.connect.assert_called_once()
def test_handle_scriptas_invalid_operation(self):
    """A script-as request with no parameters is answered with an error."""
    # Setup: Create a scripting service
    mock_connection = {}
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    ss = ScriptingService()
    ss._service_provider = utils.get_mock_service_provider(
        {CONNECTION_SERVICE_NAME: cs})

    # If: I send a script-as request with missing (None) params
    rc: RequestFlowValidator = RequestFlowValidator()
    # NOTE(review): type(None) presumably means the error carries no data payload —
    # confirm against RequestFlowValidator.add_expected_error
    rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)
    ss._handle_scriptas_request(rc.request_context, None)

    # Then:
    # ... I should get an error response
    rc.validate()
class TestMySQLConnectionService(unittest.TestCase):
    """Methods for testing the connection service with a MySQL Connection"""

    def setUp(self):
        """Set up the tests with a connection service"""
        self.connection_service = ConnectionService()
        self.connection_service._service_provider = utils.get_mock_service_provider(
            {WORKSPACE_SERVICE_NAME: WorkspaceService()},
            provider_name=MYSQL_PROVIDER_NAME)

        # Cursor results hold a MySQL version string; presumably consumed when the service
        # reads the server version — confirm against the connection implementation
        mock_cursor = MockCursor(results=[['5.7.29-log']])

        # Set up the mock connection for pymysql's connect method to return
        # NOTE(review): 'dbname': 'postgres' looks copied from the PG tests; confirm the mock ignores it
        self.mock_pymysql_connection = MockPyMySQLConnection(
            parameters={
                'host': 'myserver',
                'dbname': 'postgres',
                'user': '******'
            },
            cursor=mock_cursor)

    def _assert_connection_succeeded(self, params, response):
        """Assert that connect() stored the mock pymysql connection and reported success.

        Args:
            params: The ConnectRequestParams that were passed to connect().
            response: The value connect() returned.
        """
        stored_connection = self.connection_service.owner_to_connection_map[
            params.owner_uri].get_connection(params.type)
        self.assertIs(stored_connection._conn, self.mock_pymysql_connection)
        # A connection id in the response indicates success
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertFalse(response.server_info.is_cloud)

    def test_connect(self):
        """Test that the service connects to a MySQL server"""
        # Set up the parameters for the connection
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'password': '******',
                    'host': 'myserver',
                    'dbname': 'mysql'
                }
            }
        })

        # Set up pymysql instance for connection service to call
        mock_connect_method = mock.Mock(return_value=self.mock_pymysql_connection)

        # Set up the connection service and call its connect method with the supported options
        with mock.patch('pymysql.connect', new=mock_connect_method):
            response = self.connection_service.connect(params)

        # Verify that pymysql's connection method was actually called and that the
        # response has a connection id, indicating success.
        mock_connect_method.assert_called_once()
        self._assert_connection_succeeded(params, response)

    def test_connect_with_access_token(self):
        """Test that the service connects to a MySQL server using an access token as a password"""
        # Set up the parameters for the connection
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'azureAccountToken': 'exampleToken',
                    'host': 'myserver',
                    'dbname': 'mysql'
                }
            }
        })

        # Set up pymysql instance for connection service to call
        mock_connect_method = mock.Mock(return_value=self.mock_pymysql_connection)

        # Set up the connection service and call its connect method with the supported options
        with mock.patch('pymysql.connect', new=mock_connect_method):
            response = self.connection_service.connect(params)

        # Verify that pymysql's connection method was called with password set to account token.
        mock_connect_method.assert_called_once_with(
            user='******',
            password='******',
            host='myserver',
            port=DEFAULT_PORT[MYSQL_PROVIDER_NAME],
            database='mysql')

        # Verify that the response has a connection id, indicating success.
        self._assert_connection_succeeded(params, response)
class TestConnectionCancellation(unittest.TestCase):
    """Methods for testing connection cancellation requests"""

    def setUp(self):
        """Set up the tests with common connection parameters"""
        # Set up the mock connection service and connection info
        self.connection_service = ConnectionService()
        self.connection_service._service_provider = utils.get_mock_service_provider(
            {WORKSPACE_SERVICE_NAME: WorkspaceService()})
        self.owner_uri = 'test_uri'
        self.connection_type = ConnectionType.DEFAULT
        self.connect_params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': self.owner_uri,
            'type': self.connection_type,
            'connection': {
                'options': {}
            }
        })
        self.mock_psycopg_connection = MockPsycopgConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Cancellation tokens recorded by _mock_connect, which stands in for psycopg2.connect.
        # Recording the token during the connect call lets tests inspect its state as it would
        # be during a long-running connection.
        self.token_store = []

    def test_connecting_sets_cancellation_token(self):
        """Test that a cancellation token is set before a connection thread attempts to connect"""
        # If I attempt to connect
        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=self._mock_connect)):
            response = self.connection_service.connect(self.connect_params)

        # Then the cancellation token should have been set and should not have been canceled
        self.assertEqual(len(self.token_store), 1)
        self.assertFalse(self.token_store[0].canceled)

        # And the cancellation token should have been cleared when the connection succeeded
        self.assertIsNone(response.error_message)
        self.assertNotIn((self.owner_uri, self.connection_type),
                         self.connection_service._cancellation_map)

    def test_connection_failed_removes_own_token(self):
        """Test that the cancellation token is removed after a connection fails"""
        # If I attempt to connect
        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=Exception())):
            response = self.connection_service.connect(self.connect_params)

        # Then the cancellation token should have been cleared when the connection failed
        self.assertIsNotNone(response.error_message)
        self.assertNotIn((self.owner_uri, self.connection_type),
                         self.connection_service._cancellation_map)

    def test_connecting_cancels_previous_connection(self):
        """Test that opening a new connection while one is ongoing cancels the previous connection"""
        # Set up psycopg2's connection method to kick off a new connection. This simulates the case
        # where a call to psycopg2.connect is taking a long time and another connection request for
        # the same URI and connection type comes in and finishes before the current connection
        with mock.patch('psycopg2.connect',
                        new=mock.Mock(side_effect=self._mock_connect)) as mock_psycopg2_connect:
            old_mock_connect = mock_psycopg2_connect.side_effect

            def first_mock_connect(**kwargs):
                """Mock connection method to store the current cancellation token, and kick off another connection"""
                mock_connection = self._mock_connect()
                # Restore the plain recorder first so the nested connect() does not re-enter here
                mock_psycopg2_connect.side_effect = old_mock_connect
                self.connection_service.connect(self.connect_params)
                return mock_connection

            mock_psycopg2_connect.side_effect = first_mock_connect

            # If I attempt to connect, and then kick off a new connection while connecting
            response = self.connection_service.connect(self.connect_params)

        # Then the connection should have been canceled and returned none
        self.assertIsNone(response)

        # And the recorded cancellation tokens should show that the first request was canceled
        self.assertEqual(len(self.token_store), 2)
        self.assertTrue(self.token_store[0].canceled)
        self.assertFalse(self.token_store[1].canceled)

    def test_newer_cancellation_token_not_removed(self):
        """Test that a newer connection's cancellation token is not removed after a connection completes"""
        # Set up psycopg2's connection method to simulate a new connection by overriding the
        # current cancellation token. This simulates the case where a call to psycopg2.connect is
        # taking a long time and another connection request for the same URI and connection type
        # comes in and finishes after the current connection
        cancellation_token = CancellationToken()
        cancellation_key = (self.owner_uri, self.connection_type)

        def override_mock_connect(**kwargs):
            """Mock connection method to override the current connection token, as if another connection is executing"""
            mock_connection = self._mock_connect()
            self.connection_service._cancellation_map[cancellation_key].cancel()
            self.connection_service._cancellation_map[cancellation_key] = cancellation_token
            return mock_connection

        with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=override_mock_connect)):
            # If I attempt to connect, and the cancellation token gets updated while connecting
            response = self.connection_service.connect(self.connect_params)

        # Then the connection should have been canceled and returned none
        self.assertIsNone(response)

        # And the current cancellation token should not have been removed
        self.assertIs(self.connection_service._cancellation_map[cancellation_key],
                      cancellation_token)

    def test_handle_cancellation_request(self):
        """Test that handling a cancellation request modifies the cancellation token for a matched connection"""
        # Set up the connection service with a mock request handler and cancellation token
        cancellation_key = (self.owner_uri, self.connection_type)
        cancellation_token = CancellationToken()
        self.connection_service._cancellation_map[cancellation_key] = cancellation_token
        request_context = utils.MockRequestContext()

        # If I call the cancellation request handler
        cancel_params = CancelConnectParams(self.owner_uri, self.connection_type)
        self.connection_service.handle_cancellation_request(request_context, cancel_params)

        # Then the handler should have responded and set the cancellation flag
        request_context.send_response.assert_called_once_with(True)
        self.assertTrue(cancellation_token.canceled)

    def test_handle_cancellation_no_match(self):
        """Test that handling a cancellation request returns false if there is no matching connection to cancel"""
        # Set up a mock request handler
        request_context = utils.MockRequestContext()

        # If I call the cancellation request handler
        cancel_params = CancelConnectParams(self.owner_uri, self.connection_type)
        self.connection_service.handle_cancellation_request(request_context, cancel_params)

        # Then the handler should have responded false to indicate that no matching connection was in progress
        request_context.send_response.assert_called_once_with(False)

    # NOTE(review): this test exercises a plain PostgreSQL connect, not cancellation; it
    # probably belongs with the PG connection-service tests.
    def test_connect_with_access_token(self):
        """Test that the service connects to a PostgreSQL server using an access token as a password"""
        # Set up the parameters for the connection
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': 'someUri',
            'type': ConnectionType.DEFAULT,
            'connection': {
                'options': {
                    'user': '******',
                    'azureAccountToken': 'exampleToken',
                    'host': 'myserver',
                    'dbname': 'postgres'
                }
            }
        })

        # Set up psycopg2 instance for connection service to call
        mock_connect_method = mock.Mock(return_value=self.mock_psycopg_connection)

        # Set up the connection service and call its connect method with the supported options
        with mock.patch('psycopg2.connect', new=mock_connect_method):
            response = self.connection_service.connect(params)

        # Verify that psycopg2's connection method was called with password set to account token.
        mock_connect_method.assert_called_once_with(
            user='******',
            password='******',
            host='myserver',
            port=DEFAULT_PORT[PG_PROVIDER_NAME],
            dbname='postgres')

        # Verify that psycopg2's connection method was called and that the
        # response has a connection id, indicating success.
        self.assertIs(
            self.connection_service.owner_to_connection_map[
                params.owner_uri].get_connection(params.type)._conn,
            self.mock_psycopg_connection)
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertFalse(response.server_info.is_cloud)

    def _mock_connect(self, **kwargs):
        """Implementation for the mock psycopg2.connect method that saves the current cancellation token"""
        self.token_store.append(
            self.connection_service._cancellation_map[(self.owner_uri, self.connection_type)])
        return self.mock_psycopg_connection
class TestPGConnectionService(unittest.TestCase): """Methods for testing the connection service with a PG connection""" def setUp(self): """Set up the tests with a connection service""" self.connection_service = ConnectionService() self.connection_service._service_provider = utils.get_mock_service_provider( {WORKSPACE_SERVICE_NAME: WorkspaceService()}) def test_connect(self): """Test that the service connects to a PostgreSQL server""" # Set up the parameters for the connection params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': 'someUri', 'type': ConnectionType.DEFAULT, 'connection': { 'options': { 'user': '******', 'password': '******', 'host': 'myserver', 'dbname': 'postgres' } } }) # Set up the mock connection for psycopg2's connect method to return mock_connection = MockPsycopgConnection(dsn_parameters={ 'host': 'myserver', 'dbname': 'postgres', 'user': '******' }) # Set up the connection service and call its connect method with the supported options with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)): response = self.connection_service.connect(params) # Verify that psycopg2's connection method was called and that the # response has a connection id, indicating success. 
self.assertIs( self.connection_service.owner_to_connection_map[ params.owner_uri].get_connection(params.type)._conn, mock_connection) self.assertIsNotNone(response.connection_id) self.assertIsNotNone(response.server_info.server_version) self.assertFalse(response.server_info.is_cloud) def test_server_info_is_cloud(self): """Test that the connection response handles cloud connections correctly""" self.server_info_is_cloud_internal('postgres.database.azure.com', True) self.server_info_is_cloud_internal('postgres.database.windows.net', True) self.server_info_is_cloud_internal('some.host.com', False) def server_info_is_cloud_internal(self, host_suffix, is_cloud): """Test that the connection response handles cloud connections correctly""" # Set up the parameters for the connection connection_uri = 'someuri' connection_details = ConnectionDetails() connection_details.options = { 'user': '******', 'password': '******', 'host': f'myserver{host_suffix}', 'dbname': 'postgres' } connection_type = ConnectionType.DEFAULT # Set up the mock connection for psycopg2's connect method to return mock_connection = MockPsycopgConnection( dsn_parameters={ 'host': f'myserver{host_suffix}', 'dbname': 'postgres', 'user': '******' }) # Set up the connection service and call its connect method with the # supported options with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)): response = self.connection_service.connect( ConnectRequestParams(connection_details, connection_uri, connection_type)) # Verify that the response's serverInfo.isCloud attribute is set correctly self.assertIsNotNone(response.connection_id) self.assertIsNotNone(response.server_info.server_version) self.assertEqual(response.server_info.is_cloud, is_cloud) def test_changing_options_disconnects_existing_connection(self): """ Test that the connect method disconnects an existing connection when trying to open the same connection with different options """ # Set up the test with mock data connection_uri 
= 'someuri' connection_type = ConnectionType.DEFAULT mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******') # Insert a ConnectionInfo object into the connection service's map old_connection_details = ConnectionDetails.from_data({ 'host': 'myserver', 'dbname': 'postgres', 'user': '******', 'abc': 123 }) old_connection_info = ConnectionInfo(connection_uri, old_connection_details) old_connection_info.add_connection(connection_type, mock_connection) self.connection_service.owner_to_connection_map[ connection_uri] = old_connection_info # Create a different request with the same owner uri params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': connection_uri, 'type': connection_type, 'connection': { 'options': { 'host': 'myserver', 'dbname': 'postgres', 'user': '******', 'abc': 234 } } }) # Connect with different options, and verify that disconnect was called with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)): self.connection_service.connect(params) mock_connection.close.assert_called_once() def test_same_options_uses_existing_connection(self): """Test that the connect method uses an existing connection when connecting again with the same options""" # Set up the test with mock data connection_uri = 'someuri' connection_type = ConnectionType.DEFAULT mock_psycopg_connection = MockPsycopgConnection(dsn_parameters={ 'host': 'myserver', 'dbname': 'postgres', 'user': '******' }) mock_server_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******') # Insert a ConnectionInfo object into the connection service's map old_connection_details = ConnectionDetails.from_data({ 'host': 'myserver', 'dbname': 'postgres', 'user': '******', 'abc': 123 }) old_connection_info = ConnectionInfo(connection_uri, old_connection_details) old_connection_info.add_connection(connection_type, mock_server_connection) self.connection_service.owner_to_connection_map[ 
connection_uri] = old_connection_info # Connect with identical options, and verify that disconnect was not called params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': connection_uri, 'type': connection_type, 'connection': { 'options': old_connection_details.options } }) with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_psycopg_connection) ) as mock_psycopg2_connect: response = self.connection_service.connect(params) mock_psycopg2_connect.assert_not_called() mock_psycopg_connection.close.assert_not_called() self.assertIsNotNone(response.connection_id) def test_response_when_connect_fails(self): """Test that the proper response is given when a connection fails""" error_message = 'some error' params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': 'someUri', 'type': ConnectionType.DEFAULT, 'connection': { 'options': { 'connectionString': '' } } }) with mock.patch('psycopg2.connect', new=mock.Mock(side_effect=Exception(error_message))): response = self.connection_service.connect(params) # The response should not have a connection ID and should contain the error message self.assertIsNone(response.connection_id) self.assertEqual(response.error_message, error_message) def test_register_on_connect_callback(self): """Tests that callbacks are added to a list of callbacks as expected""" callback = MagicMock() self.connection_service.register_on_connect_callback(callback) self.assertListEqual(self.connection_service._on_connect_callbacks, [callback]) def test_on_connect_backs_called_on_connect(self): self.run_on_connect_callback(ConnectionType.DEFAULT, True) self.run_on_connect_callback(ConnectionType.EDIT, False) self.run_on_connect_callback(ConnectionType.INTELLISENSE, False) self.run_on_connect_callback(ConnectionType.QUERY, False) def run_on_connect_callback(self, conn_type: ConnectionType, expect_callback: bool) -> None: """Inner function for callback tests that verifies expected behavior given different 
connection types""" callbacks = [MagicMock(), MagicMock()] for callback in callbacks: self.connection_service.register_on_connect_callback(callback) # Set up the parameters for the connection connection_uri = 'someuri' connection_details = ConnectionDetails() connection_details.options = { 'user': '******', 'password': '******', 'host': f'myserver', 'dbname': 'postgres' } connection_type = conn_type # Set up the mock connection for psycopg2's connect method to return mock_psycopg_connection = MockPsycopgConnection( dsn_parameters={ 'host': f'myserver', 'dbname': 'postgres', 'user': '******' }) # Set up the connection service and call its connect method with the # supported options with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_psycopg_connection)): self.connection_service.connect( ConnectRequestParams(connection_details, connection_uri, connection_type)) self.connection_service.get_connection(connection_uri, conn_type) # ... The mock config change callbacks should have been called for callback in callbacks: if (expect_callback): callback.assert_called_once() # Verify call args match expected callargs: ConnectionInfo = callback.call_args[0][0] self.assertEqual(callargs.owner_uri, connection_uri) else: callback.assert_not_called() def test_disconnect_single_type(self): """Test that the disconnect method calls close on a single open connection type when a type is given""" # Set up the test with mock data connection_uri = 'someuri' connection_type_1 = ConnectionType.DEFAULT connection_type_2 = ConnectionType.EDIT mock_connection_1 = MockPGServerConnection(cur=None, host='myserver1', name='postgres1', user='******') mock_connection_2 = MockPGServerConnection(cur=None, host='myserver2', name='postgres2', user='******') # Insert a ConnectionInfo object into the connection service's map old_connection_details = ConnectionDetails.from_data({'abc': 123}) old_connection_info = ConnectionInfo(connection_uri, old_connection_details) 
old_connection_info.add_connection(connection_type_1, mock_connection_1) old_connection_info.add_connection(connection_type_2, mock_connection_2) self.connection_service.owner_to_connection_map[ connection_uri] = old_connection_info # Close the connection by calling disconnect response = self.connection_service._close_connections( old_connection_info, connection_type_1) mock_connection_1.close.assert_called_once() mock_connection_2.close.assert_not_called() self.assertTrue(response) def test_disconnect_all_types(self): """Test that the disconnect method calls close on a all open connection types when no type is given""" # Set up the test with mock data connection_uri = 'someuri' connection_type_1 = ConnectionType.DEFAULT connection_type_2 = ConnectionType.EDIT mock_connection_1 = MockPGServerConnection(cur=None, host='myserver1', name='postgres1', user='******') mock_connection_2 = MockPGServerConnection(cur=None, host='myserver2', name='postgres2', user='******') # Insert a ConnectionInfo object into the connection service's map old_connection_details = ConnectionDetails.from_data({'abc': 123}) old_connection_info = ConnectionInfo(connection_uri, old_connection_details) old_connection_info.add_connection(connection_type_1, mock_connection_1) old_connection_info.add_connection(connection_type_2, mock_connection_2) self.connection_service.owner_to_connection_map[ connection_uri] = old_connection_info # Close the connection by calling disconnect response = self.connection_service._close_connections( old_connection_info) mock_connection_1.close.assert_called_once() mock_connection_2.close.assert_called_once() self.assertTrue(response) def test_disconnect_for_invalid_connection(self): """Test that the disconnect method returns false when called on a connection that does not exist""" # Set up the test with mock data connection_uri = 'someuri' connection_type_1 = ConnectionType.DEFAULT mock_connection_1 = MockPGServerConnection(cur=None, host='myserver1', 
name='postgres1', user='******') # Insert a ConnectionInfo object into the connection service's map old_connection_details = ConnectionDetails.from_data({'abc': 123}) old_connection_info = ConnectionInfo(connection_uri, old_connection_details) old_connection_info.add_connection(connection_type_1, mock_connection_1) self.connection_service.owner_to_connection_map[ connection_uri] = old_connection_info # Close the connection by calling disconnect response = self.connection_service._close_connections( old_connection_info, ConnectionType.EDIT) mock_connection_1.close.assert_not_called() self.assertFalse(response) def test_handle_disconnect_request_unknown_uri(self): """Test that the handle_disconnect_request method returns false when the given URI is unknown""" # Setup: Create a mock request context rc = utils.MockRequestContext() # If: I request to disconnect an unknown URI params: DisconnectRequestParams = DisconnectRequestParams.from_dict( {'ownerUri': 'nonexistent'}) self.connection_service.handle_disconnect_request(rc, params) # Then: Send result should have been called once with False rc.send_response.assert_called_once_with(False) rc.send_notification.assert_not_called() rc.send_error.assert_not_called() def test_handle_connect_request(self): """Test that the handle_connect_request method kicks off a new thread to do the connection""" # Setup: Create a mock request context to handle output rc = utils.MockRequestContext() connect_response = ConnectionCompleteParams() self.connection_service.connect = Mock(return_value=connect_response) # If: I make a request to connect params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': 'someUri', 'type': ConnectionType.QUERY, 'connection': { 'server_name': 'someserver', 'user_name': 'someuser', 'database_name': 'somedb', 'options': { 'password': '******' } } }) # Connect and wait for the thread to finish executing, then verify the connection information self.connection_service.handle_connect_request(rc, 
params) connection_thread = self.connection_service.owner_to_thread_map[ params.owner_uri] self.assertIsNotNone(connection_thread) connection_thread.join() # Then: # ... Connect should have been called once self.connection_service.connect.assert_called_once_with(params) # ... A True should have been sent as the response to the request rc.send_response.assert_called_once_with(True) # ... A connection complete notification should have been sent back as well rc.send_notification.assert_called_once_with( CONNECTION_COMPLETE_METHOD, connect_response) # ... An error should not have been called rc.send_error.assert_not_called() def test_handle_database_change_request_with_empty_connection_info_for_false( self): """Test that the handle_connect_request method kicks off a new thread to do the connection""" # Setup: Create a mock request context to handle output rc = utils.MockRequestContext() connect_response = ConnectionCompleteParams() self.connection_service.connect = Mock(return_value=connect_response) params: ChangeDatabaseRequestParams = ChangeDatabaseRequestParams.from_dict( { 'owner_uri': 'someUri', 'new_database': 'newDb' }) result: bool = self.connection_service.handle_change_database_request( rc, params) self.assertEqual(result, False) def test_handle_database_change_request(self): """Test that the handle_connect_request method kicks off a new thread to do the connection""" # Setup: Create a mock request context to handle output rc = utils.MockRequestContext() connect_response = ConnectionCompleteParams() self.connection_service.connect = Mock(return_value=connect_response) self.connection_service.get_connection_info = mock.MagicMock() params: ChangeDatabaseRequestParams = ChangeDatabaseRequestParams.from_dict( { 'owner_uri': 'someUri', 'new_database': 'newDb' }) self.connection_service.handle_change_database_request(rc, params) connection_thread = self.connection_service.owner_to_thread_map[ params.owner_uri] self.assertIsNotNone(connection_thread) 
connection_thread.join() self.connection_service.connect.assert_called_once() # ... A True should have been sent as the response to the request rc.send_response.assert_called_once_with(True) # ... A connection complete notification should have been sent back as well rc.send_notification.assert_called_once_with( CONNECTION_COMPLETE_METHOD, connect_response) # ... An error should not have been called rc.send_error.assert_not_called() def test_list_databases(self): """Test that the list databases handler correctly lists the connection's databases""" # Set up the test with mock data mock_query_results = [('database1', ), ('database2', )] connection_uri = 'someuri' mock_psycopg_connection = MockPsycopgConnection( dsn_parameters={ 'host': 'myserver', 'dbname': 'postgres', 'user': '******' }, cursor=MockCursor(mock_query_results)) mock_request_context = utils.MockRequestContext() # Insert a ConnectionInfo object into the connection service's map connection_details = ConnectionDetails.from_data({}) connection_info = ConnectionInfo(connection_uri, connection_details) self.connection_service.owner_to_connection_map[ connection_uri] = connection_info # Verify that calling the listdatabases handler returns the expected databases params = ListDatabasesParams() params.owner_uri = connection_uri with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_psycopg_connection)): self.connection_service.handle_list_databases( mock_request_context, params) expected_databases = [result[0] for result in mock_query_results] self.assertEqual( mock_request_context.last_response_params.database_names, expected_databases) def test_get_connection_for_existing_connection(self): """Test that get_connection returns a connection that already exists for the given URI and type""" # Set up the test with mock data connection_uri = 'someuri' connection_type = ConnectionType.EDIT mock_connection = MockPsycopgConnection(dsn_parameters={ 'host': 'myserver', 'dbname': 'postgres', 'user': '******' 
}) # Insert a ConnectionInfo object into the connection service's map connection_details = ConnectionDetails.from_data({}) connection_info = ConnectionInfo(connection_uri, connection_details) self.connection_service.owner_to_connection_map[ connection_uri] = connection_info # Get the connection without first creating it with mock.patch( 'psycopg2.connect', new=mock.Mock( return_value=mock_connection)) as mock_psycopg2_connect: connection = self.connection_service.get_connection( connection_uri, connection_type) mock_psycopg2_connect.assert_called_once() self.assertEqual(connection._conn, mock_connection) def test_get_connection_creates_connection(self): """Test that get_connection creates a new connection when none exists for the given URI and type""" # Set up the test with mock data connection_uri = 'someuri' connection_type = ConnectionType.EDIT mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******') # Insert a ConnectionInfo object into the connection service's map connection_details = ConnectionDetails.from_data({}) connection_info = ConnectionInfo(connection_uri, connection_details) self.connection_service.owner_to_connection_map[ connection_uri] = connection_info with mock.patch( 'ossdbtoolsservice.driver.connection_manager.ConnectionManager._create_connection', new=mock.Mock( return_value=mock_connection)) as mock_psycopg2_connect: # Open the connection self.connection_service.connect( ConnectRequestParams(connection_details, connection_uri, connection_type)) # Get the connection connection = self.connection_service.get_connection( connection_uri, connection_type) self.assertEqual(connection, mock_connection) mock_psycopg2_connect.assert_called_once() def test_get_connection_for_invalid_uri(self): """Test that get_connection raises an error if the given URI is unknown""" with self.assertRaises(ValueError): self.connection_service.get_connection('someuri', ConnectionType.DEFAULT) def 
test_list_databases_handles_invalid_uri(self): """Test that the connection/listdatabases handler returns an error when the given URI is unknown""" mock_request_context = utils.MockRequestContext() params = ListDatabasesParams() params.owner_uri = 'unknown_uri' self.connection_service.handle_list_databases(mock_request_context, params) self.assertIsNone(mock_request_context.last_notification_method) self.assertIsNone(mock_request_context.last_notification_params) self.assertIsNone(mock_request_context.last_response_params) self.assertIsNotNone(mock_request_context.last_error_message) def test_list_databases_handles_query_failure(self): """Test that the list databases handler returns an error if the list databases query fails for any reason""" # Set up the test with mock data mock_query_results = [('database1', ), ('database2', )] connection_uri = 'someuri' mock_cursor = MockCursor(mock_query_results) mock_cursor.fetchall.side_effect = psycopg2.ProgrammingError('') mock_connection = MockPGServerConnection(cur=mock_cursor, host='myserver', name='postgres', user='******') mock_request_context = utils.MockRequestContext() # Insert a ConnectionInfo object into the connection service's map connection_details = ConnectionDetails.from_data({}) connection_info = ConnectionInfo(connection_uri, connection_details) self.connection_service.owner_to_connection_map[ connection_uri] = connection_info # Verify that calling the listdatabases handler returns the expected # databases params = ListDatabasesParams() params.owner_uri = connection_uri with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)): self.connection_service.handle_list_databases( mock_request_context, params) self.assertIsNone(mock_request_context.last_notification_method) self.assertIsNone(mock_request_context.last_notification_params) self.assertIsNone(mock_request_context.last_response_params) self.assertIsNotNone(mock_request_context.last_error_message) def 
test_build_connection_response(self): """Test that the connection response is built correctly""" # Set up the test with mock data server_name = 'testserver' db_name = 'testdb' user = '******' mock_connection = MockPGServerConnection(cur=None, host=server_name, name=db_name, user=user) connection_type = ConnectionType.EDIT connection_details = ConnectionDetails.from_data(opts={}) owner_uri = 'test_uri' connection_info = ConnectionInfo(owner_uri, connection_details) connection_info._connection_map = {connection_type: mock_connection} # If I build a connection response for the connection response = ossdbtoolsservice.connection.connection_service._build_connection_response( connection_info, connection_type) # Then the response should have accurate information about the connection self.assertEqual(response.owner_uri, owner_uri) self.assertEqual( response.server_info.server_version, str(mock_connection.server_version[0]) + "." + str(mock_connection.server_version[1]) + "." + str(mock_connection.server_version[2])) self.assertEqual(response.server_info.is_cloud, False) self.assertEqual(response.connection_summary.server_name, server_name) self.assertEqual(response.connection_summary.database_name, db_name) self.assertEqual(response.connection_summary.user_name, user) self.assertEqual(response.type, connection_type) def test_default_database(self): """Test that if no database is given, the default database is used""" # Set up the connection params and default database name default_db = 'test_db' self.connection_service._service_provider[ WORKSPACE_SERVICE_NAME].configuration.pgsql.default_database = default_db params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': 'someUri', 'type': ConnectionType.DEFAULT, 'connection': { 'options': { 'user': '******', 'password': '******', 'host': 'myserver', 'dbname': '' } } }) # If I connect with an empty database name with mock.patch('ossdbtoolsservice.connection.connection_service._build_connection_response'), \ 
mock.patch('psycopg2.connect', return_value=MockPsycopgConnection()) as mock_psycopg2_connect: self.connection_service.connect(params) # Then psycopg2's connect method was called with the default database calls = mock_psycopg2_connect.mock_calls self.assertEqual(len(calls), 1) self.assertEqual(calls[0][2]['dbname'], default_db) def test_non_default_database(self): """Test that if a database is given, the default database is not used""" # Set up the connection params and default database name default_db = 'test_db' actual_db = 'postgres' self.connection_service._service_provider[ WORKSPACE_SERVICE_NAME].configuration.pgsql.default_database = default_db params: ConnectRequestParams = ConnectRequestParams.from_dict({ 'ownerUri': 'someUri', 'type': ConnectionType.DEFAULT, 'connection': { 'options': { 'user': '******', 'password': '******', 'host': 'myserver', 'dbname': actual_db } } }) # If I connect with an empty database name with mock.patch('ossdbtoolsservice.connection.connection_service._build_connection_response'), \ mock.patch('psycopg2.connect', return_value=MockPsycopgConnection()) as mock_psycopg2_connect: self.connection_service.connect(params) # Then psycopg2's connect method was called with the default database calls = mock_psycopg2_connect.mock_calls self.assertEqual(len(calls), 1) self.assertNotEqual(calls[0][2]['dbname'], default_db) self.assertEqual(calls[0][2]['dbname'], actual_db) def test_get_connection_info(self): """Test that get_connection_info returns the ConnectionInfo object corresponding to a connection""" # Set up the test with mock data connection_uri = 'someuri' # Insert a ConnectionInfo object into the connection service's map connection_details = ConnectionDetails.from_data({}) connection_info = ConnectionInfo(connection_uri, connection_details) self.connection_service.owner_to_connection_map[ connection_uri] = connection_info # Get the connection info actual_connection_info = self.connection_service.get_connection_info( connection_uri) 
self.assertIs(actual_connection_info, connection_info) def test_get_connection_info_no_connection(self): """Test that get_connection_info returns None when there is no connection for the given owner URI""" # Set up the test with mock data connection_uri = 'someuri' # Get the connection info actual_connection_info = self.connection_service.get_connection_info( connection_uri) self.assertIsNone(actual_connection_info)