def test_disconnect_all_types(self):
    """Test that _close_connections closes every open connection type when no type is given"""
    # Set up two mock connections that will be registered under different types
    connection_uri = 'someuri'
    default_connection = MockPGServerConnection(cur=None, host='myserver1', name='postgres1', user='******')
    edit_connection = MockPGServerConnection(cur=None, host='myserver2', name='postgres2', user='******')

    # Register a ConnectionInfo holding both connections in the connection service's map
    connection_info = ConnectionInfo(connection_uri, ConnectionDetails.from_data({'abc': 123}))
    for conn_type, conn in ((ConnectionType.DEFAULT, default_connection),
                            (ConnectionType.EDIT, edit_connection)):
        connection_info.add_connection(conn_type, conn)
    self.connection_service.owner_to_connection_map[connection_uri] = connection_info

    # Close everything by omitting the connection type
    closed = self.connection_service._close_connections(connection_info)

    # Both connections were closed exactly once and the call reported success
    default_connection.close.assert_called_once()
    edit_connection.close.assert_called_once()
    self.assertTrue(closed)
def test_build_connection_response(self):
    """Test that _build_connection_response reports the owner URI, server info, summary and type"""
    server_name = 'testserver'
    db_name = 'testdb'
    user = '******'
    owner_uri = 'test_uri'
    connection_type = ConnectionType.EDIT
    mock_connection = MockPGServerConnection(cur=None, host=server_name, name=db_name, user=user)

    # A ConnectionInfo whose only connection is the mock EDIT connection
    connection_info = ConnectionInfo(owner_uri, ConnectionDetails.from_data(opts={}))
    connection_info._connection_map = {connection_type: mock_connection}

    # If I build a connection response for the connection
    response = ossdbtoolsservice.connection.connection_service._build_connection_response(
        connection_info, connection_type)

    # Then the response should describe the connection accurately
    expected_version = '.'.join(str(mock_connection.server_version[i]) for i in range(3))
    self.assertEqual(response.owner_uri, owner_uri)
    self.assertEqual(response.server_info.server_version, expected_version)
    self.assertEqual(response.server_info.is_cloud, False)

    summary = response.connection_summary
    self.assertEqual(summary.server_name, server_name)
    self.assertEqual(summary.database_name, db_name)
    self.assertEqual(summary.user_name, user)
    self.assertEqual(response.type, connection_type)
def test_list_databases_handles_query_failure(self):
    """Test that the list databases handler returns an error if the list databases query fails for any reason"""
    connection_uri = 'someuri'

    # The cursor holds two database rows, but fetching them raises a ProgrammingError
    failing_cursor = MockCursor([('database1', ), ('database2', )])
    failing_cursor.fetchall.side_effect = psycopg2.ProgrammingError('')
    mock_connection = MockPGServerConnection(cur=failing_cursor, host='myserver', name='postgres', user='******')
    request_context = utils.MockRequestContext()

    # Register a ConnectionInfo for the URI in the connection service's map
    connection_info = ConnectionInfo(connection_uri, ConnectionDetails.from_data({}))
    self.connection_service.owner_to_connection_map[connection_uri] = connection_info

    # When the list databases handler runs against the failing connection
    params = ListDatabasesParams()
    params.owner_uri = connection_uri
    with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)):
        self.connection_service.handle_list_databases(request_context, params)

    # Then only an error is sent back: no notification and no response
    self.assertIsNone(request_context.last_notification_method)
    self.assertIsNone(request_context.last_notification_params)
    self.assertIsNone(request_context.last_response_params)
    self.assertIsNotNone(request_context.last_error_message)
def test_get_connection_creates_connection(self):
    """Test that get_connection creates a new connection when none exists for the given URI and type"""
    connection_uri = 'someuri'
    connection_type = ConnectionType.EDIT
    mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******')

    # Register an empty ConnectionInfo for the URI in the connection service's map
    connection_details = ConnectionDetails.from_data({})
    self.connection_service.owner_to_connection_map[connection_uri] = ConnectionInfo(
        connection_uri, connection_details)

    # Route connection creation to the mock instead of a real database driver
    create_connection_path = \
        'ossdbtoolsservice.driver.connection_manager.ConnectionManager._create_connection'
    with mock.patch(create_connection_path,
                    new=mock.Mock(return_value=mock_connection)) as mock_create_connection:
        # Open the connection, then fetch it back
        self.connection_service.connect(
            ConnectRequestParams(connection_details, connection_uri, connection_type))
        connection = self.connection_service.get_connection(connection_uri, connection_type)

    # The created mock connection is returned and was created only once
    self.assertEqual(connection, mock_connection)
    mock_create_connection.assert_called_once()
def setUp(self):
    """Set up the test by creating a query with multiple batches"""
    self.statement_list = ['select version;', 'select * from t1;']
    self.statement_str = ''.join(self.statement_list)
    self.query_uri = 'test_uri'
    execution_settings = QueryExecutionSettings(ExecutionPlanOptions(), ResultSetStorageType.FILE_STORAGE)
    self.query = Query(self.query_uri, self.statement_str, execution_settings, QueryEvents())

    # Two-column result rows served by the mock cursor
    self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
    self.cursor = MockCursor(self.mock_query_results)
    self.connection = MockPGServerConnection(cur=self.cursor)

    def make_text_column(column_name):
        """Build a text-typed DbColumn for the PG provider."""
        column = DbColumn()
        column.data_type = 'text'
        column.column_name = column_name
        column.provider = PG_PROVIDER_NAME
        return column

    # Column metadata describing the two columns of the mock results
    self.columns_info = [make_text_column('Id'), make_text_column('Value')]
    self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)
def test_init(self):
    # If: I construct a new server object
    host, port, dbname = 'host', '1234', 'dbname'
    mock_conn = MockPGServerConnection(None, name=dbname, host=host, port=port)
    server = Server(mock_conn)

    # Then:
    # ... The connection should be stored and exposed unchanged
    self.assertIsInstance(server._conn, MockPGServerConnection)
    self.assertIsInstance(server.connection, MockPGServerConnection)
    self.assertIs(server.connection, mock_conn)

    # ... Each private attribute and its public property should hold the constructor input
    for attr_name, expected in (('host', host), ('port', port), ('maintenance_db_name', dbname)):
        self.assertEqual(getattr(server, '_' + attr_name), expected)
        self.assertEqual(getattr(server, attr_name), expected)
    self.assertTupleEqual(server.version, server._conn.server_version)

    # ... Recovery options should be a lazily loaded property collection
    self.assertIsInstance(server._recovery_props, NodeLazyPropertyCollection)

    for key, collection in server._child_objects.items():
        # ... Every child object collection should be a NodeCollection
        self.assertIsInstance(collection, NodeCollection)
        # ... exposed through the pluralized, lower-cased property of the same name
        self.assertIs(getattr(server, inflection.pluralize(key.lower())), collection)
def test_handle_get_database_info_request(self):
    """Test that the database info handler responds with the correct database info"""
    uri = 'test_uri'
    db_name = 'test_db'
    user_name = 'test_user'

    # Request parameters and a context that records the response
    params = GetDatabaseInfoParameters()
    params.owner_uri = uri
    request_context = MockRequestContext()

    # The owner query yields a single row containing the user name
    mock_cursor = MockCursor([(user_name, )])
    mock_connection = MockPGServerConnection(mock_cursor, name=db_name)
    self.connection_service.get_connection = mock.Mock(return_value=mock_connection)

    # If I send a get_database_info request
    self.admin_service._handle_get_database_info_request(request_context, params)

    # Then the service responded with the expected information
    response = request_context.last_response_params
    self.assertIsInstance(response, GetDatabaseInfoResponse)
    self.assertEqual(response.database_info.options,
                     {'dbname': db_name, 'owner': user_name, 'size': None})

    # And the owner was looked up with a query that embeds the database name
    owner_query_template = ("SELECT pg_catalog.pg_get_userbyid(db.datdba) FROM pg_catalog.pg_database db "
                            "WHERE db.datname = '{}'")
    mock_cursor.execute.assert_called_once_with(owner_query_template.format(db_name))
def test_handle_create_session_successful(self):
    # Setup:
    # ... Create an OE service whose connection service reports a successful connection
    mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******', port=123)
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)

    oe = ObjectExplorerService()
    oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs})
    oe._provider = constants.PG_PROVIDER_NAME
    oe._server = Server

    # ... Create the session parameters and URI
    params, session_uri = _connection_details()

    # ... The session-created notification must describe a successful session rooted at the database
    def validate_success_notification(response: SessionCreatedParameters):
        self.assertTrue(response.success)
        self.assertEqual(response.session_id, session_uri)
        self.assertIsNone(response.error_message)

        root = response.root_node
        self.assertIsInstance(root, NodeInfo)
        self.assertEqual(root.label, TEST_DBNAME)
        self.assertEqual(root.node_path, session_uri)
        self.assertEqual(root.node_type, 'Database')

        metadata = root.metadata
        self.assertIsInstance(metadata, ObjectMetadata)
        self.assertEqual(metadata.urn, oe._session_map[session_uri].server.urn_base)
        self.assertEqual(metadata.name, oe._session_map[session_uri].server.maintenance_db_name)
        self.assertEqual(metadata.metadata_type_name, 'Database')
        self.assertFalse(root.is_leaf)

    rc = RequestFlowValidator()
    rc.add_expected_response(
        CreateSessionResponse,
        lambda param: self.assertEqual(param.session_id, session_uri))
    rc.add_expected_notification(
        SessionCreatedParameters, SESSION_CREATED_METHOD, validate_success_notification)

    # If: I create a session and wait for its initialization task to finish
    oe._handle_create_session_request(rc.request_context, params)
    oe._session_map[session_uri].init_task.join()

    # Then:
    # ... The expected response and success notification should have been sent
    rc.validate()

    # ... The session should still exist and should have connection and server setup
    self.assertIn(session_uri, oe._session_map)
    session = oe._session_map[session_uri]
    self.assertIsInstance(session.server, Server)
    self.assertTrue(session.is_ready)
def test_handle_scriptas_successful_operation(self):
    # NOTE: There's no need to test all types here, the scripter tests should handle this
    # Setup:
    # ... Create a scripting service backed by a mocked connection service
    mock_connection = MockPGServerConnection()
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    ss = ScriptingService()
    ss._service_provider = utils.get_mock_service_provider({CONNECTION_SERVICE_NAME: cs})

    # ... Every response must echo the owner URI and carry the mocked script
    def validate_response(response: ScriptAsResponse) -> None:
        self.assertEqual(response.owner_uri, TestScriptingService.MOCK_URI)
        self.assertEqual(response.script, TestScriptingService.MOCK_SCRIPT)

    scripting_object = {'type': 'Table', 'name': 'test_table', 'schema': 'test_schema'}

    # ... Replace the service's Scripter with one whose script() returns the mock script
    patch_path = 'ossdbtoolsservice.scripting.scripting_service.Scripter'
    with mock.patch(patch_path) as scripter_patch:
        mock_scripter: Scripter = Scripter(mock_connection)
        mock_scripter.script = mock.MagicMock(return_value=TestScriptingService.MOCK_SCRIPT)
        scripter_patch.return_value = mock_scripter

        for operation in ScriptOperation:
            # If: I request to script the object with this operation
            rc: RequestFlowValidator = RequestFlowValidator()
            rc.add_expected_response(ScriptAsResponse, validate_response)
            params = ScriptAsParameters.from_dict({
                'ownerUri': TestScriptingService.MOCK_URI,
                'operation': operation,
                'scripting_objects': [scripting_object]
            })
            ss._handle_scriptas_request(rc.request_context, params)

            # Then: The request should have been handled correctly
            rc.validate()

    # ... script() should have been called exactly once per operation
    call_counts = dict.fromkeys(ScriptOperation, 0)
    for call_args in mock_scripter.script.call_args_list:
        call_counts[call_args[0][0]] += 1
    for count in call_counts.values():
        self.assertEqual(count, 1)
def setUp(self):
    """Create the mock cursor/connection and the batch fixtures shared by each test."""
    # A cursor with no rows, wired into a mock server connection
    self._cursor = utils.MockCursor(None)
    self._connection = MockPGServerConnection(cur=self._cursor)

    # Batch definition
    self._batch_text = 'Select * from t1'
    self._batch_id = 1

    # Supporting objects passed to the batch under test
    self._batch_events = BatchEvents()
    self._selection_data = SelectionData()
    self._result_set = mock.MagicMock()
def test_get_obj_by_urn_wrong_server(self):
    """get_object_by_urn raises ValueError for a well-formed URN that belongs to a different server."""
    # Setup: Create a server object
    server = Server(MockPGServerConnection())

    # A syntactically valid URN (//user@host:port/...) whose user/host do not match this server.
    # The previous literal had been mangled by e-mail obfuscation ("[email protected]"), so it
    # exercised a malformed URN instead of the wrong-server path.
    invalid_urn = '//wronguser@wronghost:456/Database.123/'
    self.assertFalse(invalid_urn.startswith(server.urn_base))

    with self.assertRaises(ValueError):
        # If: I get an object by its URN with a URN that is invalid for the server
        # Then: I should get an exception
        server.get_object_by_urn(invalid_urn)
def test_get_obj_by_urn_wrong_collection(self):
    """get_object_by_urn rejects a URN whose path names a collection the server does not have."""
    # Setup: Create a server object
    server = Server(MockPGServerConnection())

    # If: I look up a URN under the server's base that points at an unknown collection
    # Then: a ValueError is raised
    with self.assertRaises(ValueError):
        server.get_object_by_urn(parse.urljoin(server.urn_base, 'Datatype.123/'))
def setUp(self):
    """Create the mock connection and the Scripter under test; runs before every unit test."""
    connection_kwargs = {'cur': None, 'port': "8080", 'host': "test", 'name': "test"}
    self.conn = MockPGServerConnection(**connection_kwargs)
    self.script = scripter.Scripter(self.conn)
def test_get_obj_by_urn_empty(self):
    """get_object_by_urn raises ValueError when the URN is missing, empty or whitespace-only."""
    # Setup: Create a server object
    server = Server(MockPGServerConnection())

    for test_case in (None, '', '\t \n\r'):
        # subTest reports every failing input individually instead of stopping at the first one
        with self.subTest(urn=test_case):
            with self.assertRaises(ValueError):
                # If: I get an object by its URN without providing a URN
                # Then: I should get an exception
                server.get_object_by_urn(test_case)
def test_init_not_connected(self):
    # If: I create a DB whose name differs from the database the server's connection uses
    db_name = 'dbname'
    server = Server(MockPGServerConnection(None, name='not_connected'))
    db = Database(server, db_name)

    # Then:
    # ... Default validation should pass
    self._init_validation(db, server, None, db_name)

    # ... The schema node collection should still be present
    # NOTE(review): an older comment said it "should not be defined", which contradicts these
    # assertions; confirm which behavior is intended
    self.assertIsNotNone(db._schemas)
    self.assertIsNotNone(db.schemas)
def test_init_connected(self):
    # If: I create a DB with the same name as the database the server is connected to
    db_name = 'dbname'
    server = Server(MockPGServerConnection(None, name=db_name))
    db = Database(server, db_name)

    # Then:
    # ... Default validation should pass
    self._init_validation(db, server, None, db_name)

    # ... The schema node collection should be a NodeCollection exposed via the schemas property
    self.assertIsInstance(db._schemas, NodeCollection)
    self.assertIs(db.schemas, db._schemas)
def test_get_obj_by_urn_success(self):
    # Setup: Create a server whose database collection holds one database with OID 123
    server = Server(MockPGServerConnection())
    expected_db = Database(server, 'test_db')
    expected_db._oid = 123
    server._child_objects[Database.__name__] = {123: expected_db}

    # If: I look the database up by its URN
    found = server.get_object_by_urn(parse.urljoin(server.urn_base, '/Database.123/'))

    # Then: I get back the very same object that was registered
    self.assertIs(found, expected_db)
def test_urn_base(self):
    # Setup: Create a server object that has a connection
    server = Server(MockPGServerConnection())

    # If: I get the URN base for the server
    urn_base = server.urn_base

    # Then: it should have the form //user@host:port built from the server's values
    match = re.compile(r'//(?P<user>.+)@(?P<host>.+):(?P<port>\d+)').match(urn_base)
    self.assertIsNotNone(match)
    parts = match.groupdict()
    self.assertEqual(parts['user'], server.connection.user_name)
    self.assertEqual(parts['host'], server.host)
    self.assertEqual(parts['port'], server.port)
def setUp(self):
    """Create the metadata factory, a mock-backed server and the table/view fixtures for each test."""
    self._smo_metadata_factory = SmoEditTableMetadataFactory()
    self._connection = MockPGServerConnection(
        cur=None, port="8080", host="test", name="test", user="******")
    self._server = Server(self._connection)

    # Schema and object names under test
    self._schema_name = 'public'
    self._table_name = 'Employee'
    self._view_name = 'Vendor'

    # Object type identifiers
    self._table_object_type = 'TABLE'
    self._view_object_type = 'VIEW'

    # A single column belonging to a test table
    self._columns = [Column(self._server, "testTable", 'testName', 'testDatatype')]
def test_changing_options_disconnects_existing_connection(self):
    """
    Test that connect closes an existing connection when the same owner URI is connected again
    with different options
    """
    connection_uri = 'someuri'
    connection_type = ConnectionType.DEFAULT
    mock_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******')
    base_options = {'host': 'myserver', 'dbname': 'postgres', 'user': '******'}

    # Register an existing connection that was opened with abc=123
    old_connection_info = ConnectionInfo(
        connection_uri, ConnectionDetails.from_data({**base_options, 'abc': 123}))
    old_connection_info.add_connection(connection_type, mock_connection)
    self.connection_service.owner_to_connection_map[connection_uri] = old_connection_info

    # Build a request for the same owner URI whose options differ only in abc=234
    params: ConnectRequestParams = ConnectRequestParams.from_dict({
        'ownerUri': connection_uri,
        'type': connection_type,
        'connection': {'options': {**base_options, 'abc': 234}},
    })

    # Connecting with the changed options must close the existing connection
    with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)):
        self.connection_service.connect(params)
    mock_connection.close.assert_called_once()
def test_create_connection_successful(self):
    # Setup: an OE service whose connection service connects and hands back a mock connection
    mock_connection = MockPGServerConnection()
    oe = ObjectExplorerService()
    cs = ConnectionService()
    cs.connect = mock.MagicMock(return_value=ConnectionCompleteParams())
    cs.get_connection = mock.MagicMock(return_value=mock_connection)
    oe._service_provider = utils.get_mock_service_provider({constants.CONNECTION_SERVICE_NAME: cs})

    params, session_uri = _connection_details()
    session = ObjectExplorerSession(session_uri, params)

    # If: I create a connection for the session
    connection = oe._create_connection(session, 'foo_database')

    # Then: the mock connection is returned after exactly one connect and one lookup
    self.assertIsNotNone(connection)
    self.assertEqual(connection, mock_connection)
    for mocked_call in (cs.connect, cs.get_connection):
        mocked_call.assert_called_once()
def test_metadata_list_request(self):
    """Test that the metadata list handler properly starts a thread to list metadata and responds with the list"""
    # Set up the parameters and mocks for the request
    expected_metadata = [
        ObjectMetadata(schema='schema1', name='table1', metadata_type=MetadataType.TABLE),
        ObjectMetadata(schema='schema1', name='view1', metadata_type=MetadataType.VIEW),
        ObjectMetadata(schema='schema1', name='function1', metadata_type=MetadataType.FUNCTION),
        ObjectMetadata(schema='schema1', name='table2', metadata_type=MetadataType.TABLE),
        ObjectMetadata(schema='schema2', name='view1', metadata_type=MetadataType.VIEW),
        ObjectMetadata(schema='schema2', name='function1', metadata_type=MetadataType.FUNCTION),
    ]
    metadata_type_to_str_map = {
        MetadataType.TABLE: 't',
        MetadataType.VIEW: 'v',
        MetadataType.FUNCTION: 'f'
    }
    # Query results have schema_name, object_name, and object_type columns in that order
    list_query_result = [(metadata.schema, metadata.name, metadata_type_to_str_map[metadata.metadata_type])
                         for metadata in expected_metadata]
    mock_cursor = MockCursor(list_query_result)
    mock_connection = MockPGServerConnection(cur=mock_cursor)
    self.connection_service.get_connection = mock.Mock(return_value=mock_connection)
    request_context = MockRequestContext()
    params = MetadataListParameters()
    params.owner_uri = self.test_uri
    mock_thread = MockThread()

    with mock.patch('threading.Thread', new=mock.Mock(side_effect=mock_thread.initialize_target)):
        # If I call the metadata list request handler
        self.metadata_service._handle_metadata_list_request(request_context, params)

    # Then the worker thread was kicked off
    self.assertEqual(mock_thread.target, self.metadata_service._metadata_list_worker)
    mock_thread.start.assert_called_once()

    # And the worker retrieved the correct connection and executed a query on it
    self.connection_service.get_connection.assert_called_once_with(self.test_uri, ConnectionType.DEFAULT)
    mock_cursor.execute.assert_called_once()

    # And the handler responded with the expected results
    self.assertIsNone(request_context.last_error_message)
    self.assertIsNone(request_context.last_notification_method)
    response = request_context.last_response_params
    self.assertIsInstance(response, MetadataListResponse)

    # Check the count first: otherwise a short or empty response would pass the loop vacuously
    self.assertEqual(len(response.metadata), len(expected_metadata))
    for actual_metadata, expected in zip(response.metadata, expected_metadata):
        self.assertIsInstance(actual_metadata, ObjectMetadata)
        self.assertEqual(actual_metadata.schema, expected.schema)
        self.assertEqual(actual_metadata.name, expected.name)
        self.assertEqual(actual_metadata.metadata_type, expected.metadata_type)
def test_same_options_uses_existing_connection(self):
    """Test that connect reuses an existing connection when connecting again with identical options"""
    connection_uri = 'someuri'
    connection_type = ConnectionType.DEFAULT
    mock_psycopg_connection = MockPsycopgConnection(dsn_parameters={
        'host': 'myserver',
        'dbname': 'postgres',
        'user': '******'
    })
    mock_server_connection = MockPGServerConnection(cur=None, host='myserver', name='postgres', user='******')

    # Register an existing connection for the owner URI
    old_connection_details = ConnectionDetails.from_data({
        'host': 'myserver',
        'dbname': 'postgres',
        'user': '******',
        'abc': 123
    })
    old_connection_info = ConnectionInfo(connection_uri, old_connection_details)
    old_connection_info.add_connection(connection_type, mock_server_connection)
    self.connection_service.owner_to_connection_map[connection_uri] = old_connection_info

    # Reconnect using exactly the same options
    params: ConnectRequestParams = ConnectRequestParams.from_dict({
        'ownerUri': connection_uri,
        'type': connection_type,
        'connection': {'options': old_connection_details.options},
    })
    patched_connect = mock.Mock(return_value=mock_psycopg_connection)
    with mock.patch('psycopg2.connect', new=patched_connect):
        response = self.connection_service.connect(params)

    # No new connection was opened, nothing was closed, and a connection id was still returned
    patched_connect.assert_not_called()
    mock_psycopg_connection.close.assert_not_called()
    self.assertIsNotNone(response.connection_id)
def test_maintenance_db(self):
    # Setup: a server connected to 'dbname' whose database collection is replaced by a mock indexer
    server = Server(MockPGServerConnection(None, name='dbname'))
    expected_db = {}
    database_collection = mock.Mock()
    database_collection.__getitem__ = mock.MagicMock(return_value=expected_db)
    server._child_objects[Database.__name__] = database_collection

    # If: I retrieve the maintenance db for the server
    maintenance_db = server.maintenance_db

    # Then: it came from the mocked collection, indexed by the connection's database name
    self.assertIs(maintenance_db, expected_db)
    database_collection.__getitem__.assert_called_once_with('dbname')
def test_refresh(self):
    # Setup: a server whose resettable collections have mocked reset methods
    server = Server(MockPGServerConnection())

    def resettable_collections():
        """Return the collections that refresh() is expected to reset."""
        return (server.databases, server.roles, server.tablespaces, server._recovery_props)

    for collection in resettable_collections():
        collection.reset = mock.MagicMock()

    # If: I refresh the server
    server.refresh()

    # Then: every collection was reset exactly once
    for collection in resettable_collections():
        collection.reset.assert_called_once()
def setUp(self):
    # Setup: Create an OE service holding one session whose server exposes a single database
    self.cs = ConnectionService()
    self.mock_connection = {}
    self.oe = ObjectExplorerService()

    params, session_uri = _connection_details()
    self.session = ObjectExplorerSession(session_uri, params)
    self.oe._session_map[session_uri] = self.session

    self.mock_server = Server(MockPGServerConnection())
    self.session.server = self.mock_server

    # The session's server has exactly one database, sharing the server's connection
    db_name = 'dbname'
    self.db = Database(self.mock_server, db_name)
    self.db._connection = self.mock_server._conn
    self.session.server._child_objects[Database.__name__] = [self.db]

    # The connection service hands back the placeholder connection and always disconnects successfully
    self.cs.get_connection = mock.MagicMock(return_value=self.mock_connection)
    self.cs.disconnect = mock.MagicMock(return_value=True)
    self.oe._service_provider = utils.get_mock_service_provider(
        {constants.CONNECTION_SERVICE_NAME: self.cs})
def test_disconnect_for_invalid_connection(self):
    """Test that _close_connections returns False when asked to close a connection type that was never opened"""
    connection_uri = 'someuri'
    default_connection = MockPGServerConnection(cur=None, host='myserver1', name='postgres1', user='******')

    # Register a ConnectionInfo that holds only a DEFAULT connection
    connection_info = ConnectionInfo(connection_uri, ConnectionDetails.from_data({'abc': 123}))
    connection_info.add_connection(ConnectionType.DEFAULT, default_connection)
    self.connection_service.owner_to_connection_map[connection_uri] = connection_info

    # Try to close the EDIT connection, which does not exist
    closed = self.connection_service._close_connections(connection_info, ConnectionType.EDIT)

    # The existing connection stays open and the call reports failure
    default_connection.close.assert_not_called()
    self.assertFalse(closed)