def test_init_session_failed_connection(self):
    """Initializing a session whose connect call fails should notify an error and drop the session."""
    # Arrange: a connection service whose connect() reports a failure
    conn_service = ConnectionService()
    failed_response = ConnectionCompleteParams()
    failed_response.error_message = 'Boom!'
    conn_service.connect = mock.MagicMock(return_value=failed_response)

    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: conn_service})

    # Register the session by hand and call _initialize_session directly; the request
    # handler is bypassed on purpose to avoid threading issues in the test.
    conn_params, uri = _connection_details()
    oe_session = ObjectExplorerSession(uri, conn_params)
    service._session_map[uri] = oe_session

    validator = RequestFlowValidator()
    validator.add_expected_notification(
        SessionCreatedParameters,
        SESSION_CREATED_METHOD,
        lambda notification: self._validate_init_error(notification, uri))

    # Act
    service._initialize_session(validator.request_context, oe_session)

    # Assert: the error notification was sent and the session map is empty again
    validator.validate()
    self.assertDictEqual(service._session_map, {})
def test_handle_create_session_threading_fail(self):
    """If the init thread cannot be created, an error notification is sent and the session is removed."""
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # Make thread creation inside the OE service module raise.
    # NOTE(review): this target names the ``pgsqltoolsservice`` package; confirm it matches
    # the package actually under test, otherwise mock.patch cannot import the target.
    thread_target = 'pgsqltoolsservice.object_explorer.object_explorer_service.threading.Thread'
    failing_thread = mock.MagicMock(side_effect=Exception('Boom!'))

    with mock.patch(thread_target, failing_thread):
        conn_params, uri = _connection_details()

        validator = RequestFlowValidator()
        validator.add_expected_response(
            CreateSessionResponse,
            lambda response: self.assertEqual(response.session_id, uri))
        validator.add_expected_notification(
            SessionCreatedParameters,
            SESSION_CREATED_METHOD,
            lambda notification: self._validate_init_error(notification, uri))

        service._handle_create_session_request(validator.request_context, conn_params)

    # The response and the error notification were both emitted ...
    validator.validate()
    # ... and the half-created session was cleaned up
    self.assertDictEqual(service._session_map, {})
def test_dmp_capabilities_have_backup_options(self):
    """A DMP capabilities response must advertise exactly one enabled 'backup' feature with options."""
    # Arrange: mocked request context and a capabilities service backed by a workspace service
    request_context = utils.MockRequestContext()
    service = CapabilitiesService()
    service._service_provider = utils.get_mock_service_provider(
        {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()})

    # Act
    service._handle_dmp_capabilities_request(request_context, None)

    # Assert
    request_context.send_response.assert_called_once()
    result = request_context.send_response.mock_calls[0][1][0]
    backup_features = list(filter(
        lambda feature: feature.feature_name == 'backup',
        result.capabilities.features))

    # Exactly one feature describes backup ...
    self.assertEqual(len(backup_features), 1)
    backup_feature, = backup_features
    # ... it is enabled, and it exposes at least one option
    self.assertTrue(backup_feature.enabled)
    self.assertGreater(len(backup_feature.options_metadata), 0)
def _handle_er_incomplete_params(method: TEventHandler):
    """Assert that ``method`` answers every malformed node-expansion request with an error.

    :param method: unbound OE service handler, invoked as ``method(service, context, params)``
    """
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # No params at all, a missing session id, and a missing node path are all invalid
    invalid_inputs = [None]
    invalid_inputs += [
        ExpandParameters.from_dict(fields)
        for fields in (
            {'session_id': None, 'node_path': '/'},
            {'session_id': 'session', 'node_path': None},
        )
    ]

    for bad_params in invalid_inputs:
        validator = RequestFlowValidator().add_expected_error(
            type(None), RequestFlowValidator.basic_error_validation)
        method(service, validator.request_context, bad_params)
        # Each invalid input must have produced an error response
        validator.validate()
def setUp(self):
    """Prepare a connection service and the default connect request shared by the tests"""
    self.connection_service = ConnectionService()
    services = {WORKSPACE_SERVICE_NAME: WorkspaceService()}
    self.connection_service._service_provider = utils.get_mock_service_provider(services)

    self.owner_uri = 'test_uri'
    self.connection_type = ConnectionType.DEFAULT
    request_fields = {
        'ownerUri': self.owner_uri,
        'type': self.connection_type,
        'connection': {'options': {}},
    }
    self.connect_params: ConnectRequestParams = ConnectRequestParams.from_dict(request_fields)

    dsn = {'host': 'myserver', 'dbname': 'postgres', 'user': '******'}
    self.mock_psycopg_connection = MockPsycopgConnection(dsn_parameters=dsn)

    # Intended to collect the cancellation tokens seen by a patched psycopg2 connect, so a
    # test can inspect the token state as it was during a long-running connection.
    self.token_store = []
def test_handle_create_session_successful(self):
    """Creating a session over a working connection should succeed and leave a ready session behind."""
    # Setup:
    # ... Create OE service with mock connection service that returns a successful connection response
    mock_connection = MockPGServerConnection(
        cur=None, host='myserver', name='postgres', user='******', port=123)
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: cs})
    oe._provider = constants.PG_PROVIDER_NAME
    # NOTE(review): the Server class itself (not an instance) is assigned; presumably the
    # service instantiates it per session — confirm against ObjectExplorerService.
    oe._server = Server

    # ... Create parameters, session, request context validator
    params, session_uri = _connection_details()

    # ... Create validation of success notification
    def validate_success_notification(response: SessionCreatedParameters):
        self.assertTrue(response.success)
        self.assertEqual(response.session_id, session_uri)
        self.assertIsNone(response.error_message)

        self.assertIsInstance(response.root_node, NodeInfo)
        self.assertEqual(response.root_node.label, TEST_DBNAME)
        self.assertEqual(response.root_node.node_path, session_uri)
        self.assertEqual(response.root_node.node_type, 'Database')
        self.assertIsInstance(response.root_node.metadata, ObjectMetadata)
        self.assertEqual(response.root_node.metadata.urn,
                         oe._session_map[session_uri].server.urn_base)
        self.assertEqual(
            response.root_node.metadata.name,
            oe._session_map[session_uri].server.maintenance_db_name)
        self.assertEqual(response.root_node.metadata.metadata_type_name, 'Database')
        self.assertFalse(response.root_node.is_leaf)

    rc = RequestFlowValidator()
    rc.add_expected_response(
        CreateSessionResponse,
        lambda param: self.assertEqual(param.session_id, session_uri))
    rc.add_expected_notification(SessionCreatedParameters,
                                 SESSION_CREATED_METHOD,
                                 validate_success_notification)

    # If: I create a session
    oe._handle_create_session_request(rc.request_context, params)
    # ... and block until the session's initialization task completes
    oe._session_map[session_uri].init_task.join()

    # Then:
    # ... The create-session response and a success notification should have been returned
    rc.validate()

    # ... The session should still exist and should have connection and server setup
    self.assertIn(session_uri, oe._session_map)
    self.assertIsInstance(oe._session_map[session_uri].server, Server)
    self.assertTrue(oe._session_map[session_uri].is_ready)
def test_handle_scriptas_successful_operation(self):
    """Every ScriptOperation should yield a response carrying the (mocked) generated script."""
    # NOTE: There's no need to test all types here, the scripter tests should handle this
    connection = MockPGServerConnection()
    conn_service = ConnectionService()
    conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    conn_service.get_connection = mock.MagicMock(return_value=connection)
    service = ScriptingService()
    service._service_provider = utils.get_mock_service_provider(
        {CONNECTION_SERVICE_NAME: conn_service})

    def check_response(response: ScriptAsResponse) -> None:
        self.assertEqual(response.owner_uri, TestScriptingService.MOCK_URI)
        self.assertEqual(response.script, TestScriptingService.MOCK_SCRIPT)

    # Replace the Scripter used by the scripting service with one whose script() is mocked
    target = 'ossdbtoolsservice.scripting.scripting_service.Scripter'
    with mock.patch(target) as scripter_cls:
        fake_scripter: Scripter = Scripter(connection)
        fake_scripter.script = mock.MagicMock(return_value=TestScriptingService.MOCK_SCRIPT)
        scripter_cls.return_value = fake_scripter

        table = {'type': 'Table', 'name': 'test_table', 'schema': 'test_schema'}

        for op in ScriptOperation:
            validator: RequestFlowValidator = RequestFlowValidator()
            validator.add_expected_response(ScriptAsResponse, check_response)
            request = ScriptAsParameters.from_dict({
                'ownerUri': TestScriptingService.MOCK_URI,
                'operation': op,
                'scripting_objects': [table],
            })

            service._handle_scriptas_request(validator.request_context, request)

            # The request was answered with the expected response
            validator.validate()

        # script() must have been invoked exactly once per operation
        counts = dict.fromkeys(ScriptOperation, 0)
        for recorded in fake_scripter.script.call_args_list:
            counts[recorded[0][0]] += 1
        for count in counts.values():
            self.assertEqual(count, 1)
def _preloaded_oe_service(
        self) -> Tuple[ObjectExplorerService, ObjectExplorerSession, str]:
    """Build an OE service that already holds one ready session backed by a mock server.

    :return: (service, session, session_uri)
    """
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    details, uri = _connection_details()
    ready_session = ObjectExplorerSession(uri, details)
    ready_session.server = mock.Mock()
    ready_session.is_ready = True

    service._session_map[uri] = ready_session
    return service, ready_session, uri
def test_mysql_language_flavor(self):
    """
    With a MySQL provider, the language service treats only files registered for the MySQL
    flavor as valid; files registered for any other flavor are ignored.
    """
    # If: I register files of several flavors
    pg_file = LanguageFlavorChangeParams.from_data(
        'file://pguri.sql', 'sql', PG_PROVIDER_NAME)
    mysql_file = LanguageFlavorChangeParams.from_data(
        'file://mysqluri.sql', 'sql', MYSQL_PROVIDER_NAME)
    ms_file = LanguageFlavorChangeParams.from_data(
        'file://msuri.sql', 'sql', MSSQL_PROVIDER_NAME)
    other_file = LanguageFlavorChangeParams.from_data(
        'file://other.doc', 'doc', '')

    # ... with a language service running under a MySQL service provider
    mysql_provider = utils.get_mock_service_provider(provider_name=MYSQL_PROVIDER_NAME)
    service = LanguageService()
    service._service_provider = mysql_provider

    context: NotificationContext = utils.get_mock_notification_context()
    for flavor_change in (pg_file, ms_file, mysql_file, other_file):
        service.handle_flavor_change(context, flavor_change)

    # Then: no notifications go out, and only the MySQL file is considered valid
    context.send_notification.assert_not_called()
    self.assertFalse(service.is_valid_uri(ms_file.uri))
    self.assertFalse(service.is_valid_uri(pg_file.uri))
    self.assertFalse(service.is_valid_uri(other_file.uri))
    self.assertTrue(service.is_valid_uri(mysql_file.uri))

    # When: the former MSSQL file is re-flavored to PGSQL it stays disabled,
    # and re-flavoring it to MySQL enables intellisense for it
    for provider, check in ((PG_PROVIDER_NAME, self.assertFalse),
                            (MYSQL_PROVIDER_NAME, self.assertTrue)):
        reflavored = LanguageFlavorChangeParams.from_data('file://msuri.sql', 'sql', provider)
        service.handle_flavor_change(context, reflavored)
        check(service.is_valid_uri(reflavored.uri))
def test_handle_scriptas_missing_params(self):
    """A scripting request without parameters should be answered with an error response."""
    service = ScriptingService()
    service._service_provider = utils.get_mock_service_provider({})

    validator: RequestFlowValidator = RequestFlowValidator()
    validator.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)

    service._handle_scriptas_request(validator.request_context, None)

    # The error response was sent
    validator.validate()
def setUp(self):
    """Create a MySQL-flavored connection service and a mock PyMySQL connection for the tests"""
    self.connection_service = ConnectionService()
    self.connection_service._service_provider = utils.get_mock_service_provider(
        {WORKSPACE_SERVICE_NAME: WorkspaceService()},
        provider_name=MYSQL_PROVIDER_NAME)

    # Cursor returning one row with a single value (looks like a MySQL version string)
    version_cursor = MockCursor(results=[['5.7.29-log']])

    # Connection object meant to be returned from pymysql's (patched) connect method
    connection_parameters = {'host': 'myserver', 'dbname': 'postgres', 'user': '******'}
    self.mock_pymysql_connection = MockPyMySQLConnection(
        parameters=connection_parameters, cursor=version_cursor)
def test_handle_create_session_missing_params(self):
    """Creating a session without parameters should produce an error and register no session."""
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    service._handle_create_session_request(validator.request_context, None)

    # An error response was sent ...
    validator.validate()
    # ... and nothing was added to the session map
    self.assertDictEqual(service._session_map, {})
def _handle_er_no_session_match(method: TEventHandler):
    """Assert that ``method`` returns an error when the requested OE session does not exist.

    :param method: unbound OE service handler, invoked as ``method(service, context, params)``
    """
    # Setup: Create an OE service with no sessions registered
    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider({})

    # If: I expand a node on a session that doesn't exist
    # NOTE: node_path must be valid here. A None node_path is one of the malformed parameter
    # sets covered by _handle_er_incomplete_params. With it, the handler could fail on
    # parameter validation and never reach the session lookup this test is meant to cover.
    rc = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    params = ExpandParameters.from_dict({
        'session_id': 'session',
        'node_path': '/'
    })
    method(oe, rc.request_context, params)

    # Then: I should get an error back
    rc.validate()
def test_create_connection_successful(self):
    """_create_connection should hand back the connection provided by the connection service."""
    expected_connection = MockConnection('test')
    service = ObjectExplorerService()
    conn_service = ConnectionService()
    conn_service.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    conn_service.get_connection = mock.MagicMock(return_value=expected_connection)
    service._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: conn_service})

    details, uri = _connection_details()
    oe_session = ObjectExplorerSession(uri, details)

    result = service._create_connection(oe_session, 'foo_database')

    self.assertIsNotNone(result)
    self.assertEqual(result, expected_connection)
    # Exactly one connect and one lookup were performed
    conn_service.connect.assert_called_once()
    conn_service.get_connection.assert_called_once()
def test_dmp_capabilities_request(self):
    """A DMP capabilities request should be answered with a single CapabilitiesResult response."""
    # Arrange: mocked request context and a capabilities service backed by a workspace service
    request_context = utils.MockRequestContext()
    service = CapabilitiesService()
    service._service_provider = utils.get_mock_service_provider(
        {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()})

    # Act
    service._handle_dmp_capabilities_request(request_context, None)

    # Assert: only a response was sent (no notification, no error) and it has the right type
    request_context.send_notification.assert_not_called()
    request_context.send_error.assert_not_called()
    request_context.send_response.assert_called_once()
    sent = request_context.send_response.mock_calls[0][1][0]
    self.assertIsInstance(sent, CapabilitiesResult)
def test_create_connection_failed(self):
    """_create_connection should raise RuntimeError carrying the connect error message."""
    expected_error = 'Failed'
    service = ObjectExplorerService()
    conn_service = ConnectionService()
    failed = ConnectionCompleteParams()
    failed.error_message = expected_error
    conn_service.connect = mock.MagicMock(return_value=failed)
    service._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: conn_service})

    details, uri = _connection_details()
    oe_session = ObjectExplorerSession(uri, details)

    with self.assertRaises(RuntimeError) as raised:
        service._create_connection(oe_session, 'foo_database')

    # Checked after the with-block: anything placed after the raising call inside it never runs
    self.assertEqual(expected_error, str(raised.exception))
    conn_service.connect.assert_called_once()
def test_handle_scriptas_invalid_operation(self):
    """A scripting request that cannot be serviced should be answered with an error response.

    NOTE(review): despite its name, this test currently sends ``None`` parameters (the same
    input as test_handle_scriptas_missing_params), so it does not exercise an invalid
    ScriptOperation value. Consider sending real ScriptAsParameters with an unsupported
    operation once the expected failure mode for that case is confirmed.
    """
    # Setup: Create a scripting service whose connection service returns a placeholder connection
    mock_connection = {}
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    ss = ScriptingService()
    ss._service_provider = utils.get_mock_service_provider(
        {CONNECTION_SERVICE_NAME: cs})

    # If: I make a scripting request (currently with no parameters, see NOTE above)
    rc: RequestFlowValidator = RequestFlowValidator()
    rc.add_expected_error(type(None), RequestFlowValidator.basic_error_validation)
    ss._handle_scriptas_request(rc.request_context, None)

    # Then:
    # ... I should get an error response
    rc.validate()
def test_handle_create_session_session_exists(self):
    """Creating a session whose URI is already registered answers False and keeps the old session."""
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    details, uri = _connection_details()
    existing = ObjectExplorerSession(uri, details)
    service._session_map[uri] = existing

    # Attempt to create the same session again
    validator = RequestFlowValidator().add_expected_response(bool, self.assertFalse)
    service._handle_create_session_request(validator.request_context, details)

    # The response was False ...
    validator.validate()
    # ... and the original session object is untouched
    self.assertIs(service._session_map[uri], existing)
def test_handle_create_session_incomplete_params(self):
    """Connection details that are missing required options yield an error and no session."""
    service = ObjectExplorerService()
    service._service_provider = utils.get_mock_service_provider({})

    # Empty connection details are enough to make session URI generation throw.
    # NOTE: the individual failure scenarios of URI generation are covered by a different test.
    empty_details = ConnectionDetails.from_data({})
    validator = RequestFlowValidator().add_expected_error(
        type(None), RequestFlowValidator.basic_error_validation)
    service._handle_create_session_request(validator.request_context, empty_details)

    # An error response was sent ...
    validator.validate()
    # ... and no session was created
    self.assertDictEqual(service._session_map, {})
def setUp(self):
    """Build an OE service holding one session whose server exposes a single mock database."""
    self.cs = ConnectionService()
    self.mock_connection = {}
    self.oe = ObjectExplorerService()

    params, session_uri = _connection_details()
    self.session = ObjectExplorerSession(session_uri, params)
    self.oe._session_map[session_uri] = self.session

    db_name = 'dbname'
    self.mock_server = Server(MockConnection(db_name))
    self.session.server = self.mock_server
    self.db = Database(self.mock_server, db_name)
    self.db._connection = MockConnection(db_name)
    # Seed the server's child-object cache (keyed by class name) with the mock database;
    # presumably this lets lookups find it without querying — confirm against the SMO code.
    self.mock_server._child_objects[Database.__name__] = [self.db]

    # The connection service hands out the placeholder connection and always disconnects cleanly
    self.cs.get_connection = mock.MagicMock(return_value=self.mock_connection)
    self.cs.disconnect = mock.MagicMock(return_value=True)
    self.oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: self.cs})
def setUp(self):
    """Create the connection service under test, wired to a workspace service"""
    self.connection_service = ConnectionService()
    services = {constants.WORKSPACE_SERVICE_NAME: WorkspaceService()}
    self.connection_service._service_provider = utils.get_mock_service_provider(services)
def setUp(self):
    """Create the disaster recovery service, its dependencies, and shared backup/restore fixtures"""
    # Services, wired together through a mock service provider
    self.disaster_recovery_service = DisasterRecoveryService()
    self.connection_service = ConnectionService()
    self.task_service = TaskService()
    self.disaster_recovery_service._service_provider = utils.get_mock_service_provider({
        constants.CONNECTION_SERVICE_NAME: self.connection_service,
        constants.TASK_SERVICE_NAME: self.task_service,
    })

    # Connection information used by the tests
    self.host = 'test_host'
    self.dbname = 'test_db'
    self.username = '******'
    self.connection_details = ConnectionDetails()
    self.connection_details.options = {
        'host': self.host,
        'dbname': self.dbname,
        'user': self.username,
        'port': 5432,
    }
    self.test_uri = 'test_uri'
    self.connection_info = ConnectionInfo(self.test_uri, self.connection_details)

    # Backup / restore parameters; both share the data_only, no_owner and schema options
    self.request_context = utils.MockRequestContext()
    self.backup_path = 'mock/path/test.sql'
    self.backup_type = 'sql'
    self.data_only = False
    self.no_owner = True
    self.schema = 'test_schema'
    shared_options = {
        'data_only': self.data_only,
        'no_owner': self.no_owner,
        'schema': self.schema,
    }
    self.backup_params = BackupParams.from_dict({
        'ownerUri': self.test_uri,
        'backupInfo': {'type': self.backup_type, 'path': self.backup_path, **shared_options},
    })

    self.restore_path = 'mock/path/test.dump'
    self.restore_params = RestoreParams.from_dict({
        'ownerUri': self.test_uri,
        'options': {'path': self.restore_path, **shared_options},
    })

    self.pg_dump_exe = 'pg_dump'
    self.pg_restore_exe = 'pg_restore'

    # A task wrapping a mock action; start() is stubbed so the task is never really started
    self.mock_action = mock.Mock()
    self.mock_task = Task(None, None, None, None, None, self.request_context, self.mock_action)
    self.mock_task.start = mock.Mock()