예제 #1
0
    def test_disconnect_all_types(self):
        """Closing a ConnectionInfo without naming a type should close every open connection"""
        # Arrange: a single owner URI that holds one mocked connection per connection type
        owner_uri = 'someuri'
        connections_by_type = {
            ConnectionType.DEFAULT: MockPGServerConnection(cur=None,
                                                           host='myserver1',
                                                           name='postgres1',
                                                           user='******'),
            ConnectionType.EDIT: MockPGServerConnection(cur=None,
                                                        host='myserver2',
                                                        name='postgres2',
                                                        user='******'),
        }

        # Register the ConnectionInfo (with both connections) in the service's map
        info = ConnectionInfo(owner_uri,
                              ConnectionDetails.from_data({'abc': 123}))
        for conn_type, conn in connections_by_type.items():
            info.add_connection(conn_type, conn)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Act: close the connections without restricting to a connection type
        result = self.connection_service._close_connections(info)

        # Assert: each connection was closed exactly once and the call reported success
        for conn in connections_by_type.values():
            conn.close.assert_called_once()
        self.assertTrue(result)
예제 #2
0
    def test_build_connection_response(self):
        """A response built from a ConnectionInfo should mirror the underlying connection's details"""
        # Arrange: a ConnectionInfo whose EDIT slot holds a mocked server connection
        expected_server = 'testserver'
        expected_db = 'testdb'
        expected_user = '******'
        fake_connection = MockPGServerConnection(cur=None,
                                                 host=expected_server,
                                                 name=expected_db,
                                                 user=expected_user)
        conn_type = ConnectionType.EDIT
        details = ConnectionDetails.from_data(opts={})
        uri = 'test_uri'
        info = ConnectionInfo(uri, details)
        info._connection_map = {conn_type: fake_connection}

        # Act: build the response for that connection type
        built = ossdbtoolsservice.connection.connection_service._build_connection_response(
            info, conn_type)

        # Assert: the version is the first three server_version components joined by dots
        expected_version = ".".join(
            str(fake_connection.server_version[index]) for index in range(3))
        self.assertEqual(built.owner_uri, uri)
        self.assertEqual(built.server_info.server_version, expected_version)
        self.assertEqual(built.server_info.is_cloud, False)

        # ... and the summary echoes the server, database and user of the connection
        summary = built.connection_summary
        self.assertEqual(summary.server_name, expected_server)
        self.assertEqual(summary.database_name, expected_db)
        self.assertEqual(summary.user_name, expected_user)
        self.assertEqual(built.type, conn_type)
예제 #3
0
    def test_list_databases_handles_query_failure(self):
        """The list databases handler should send an error when the list databases query raises"""
        # Arrange: a cursor whose fetchall raises, wrapped in a mocked server connection
        owner_uri = 'someuri'
        failing_cursor = MockCursor([('database1', ), ('database2', )])
        failing_cursor.fetchall.side_effect = psycopg2.ProgrammingError('')
        fake_connection = MockPGServerConnection(cur=failing_cursor,
                                                 host='myserver',
                                                 name='postgres',
                                                 user='******')
        request_context = utils.MockRequestContext()

        # Register a ConnectionInfo for the owner URI with the connection service
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({}))
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Act: invoke the handler while psycopg2.connect is patched to return the mock
        request_params = ListDatabasesParams()
        request_params.owner_uri = owner_uri
        connect_stub = mock.Mock(return_value=fake_connection)
        with mock.patch('psycopg2.connect', new=connect_stub):
            self.connection_service.handle_list_databases(
                request_context, request_params)

        # Assert: no notification or response was sent, only an error message
        for untouched in ('last_notification_method',
                          'last_notification_params',
                          'last_response_params'):
            self.assertIsNone(getattr(request_context, untouched))
        self.assertIsNotNone(request_context.last_error_message)
예제 #4
0
    def test_get_connection_creates_connection(self):
        """get_connection should hand back the connection created for the given URI and type"""
        # Arrange: mock data plus a ConnectionInfo registered for the owner URI
        owner_uri = 'someuri'
        conn_type = ConnectionType.EDIT
        fake_connection = MockPGServerConnection(cur=None,
                                                 host='myserver',
                                                 name='postgres',
                                                 user='******')
        details = ConnectionDetails.from_data({})
        self.connection_service.owner_to_connection_map[owner_uri] = ConnectionInfo(
            owner_uri, details)

        # Replace the connection manager's factory so no real connection is attempted
        create_patch = mock.patch(
            'ossdbtoolsservice.driver.connection_manager.ConnectionManager._create_connection',
            new=mock.Mock(return_value=fake_connection))
        with create_patch as create_connection_mock:
            # Act: open the connection, then fetch it back from the service
            request = ConnectRequestParams(details, owner_uri, conn_type)
            self.connection_service.connect(request)
            fetched = self.connection_service.get_connection(owner_uri,
                                                             conn_type)

            # Assert: the mocked connection is returned and the factory ran exactly once
            self.assertEqual(fetched, fake_connection)
            create_connection_mock.assert_called_once()
예제 #5
0
    def setUp(self):
        """Build a two-statement Query plus the mock cursor, connection and column info the tests share"""
        # The query text is the two statements concatenated with no separator
        statements = ['select version;', 'select * from t1;']
        self.statement_list = statements
        self.statement_str = ''.join(statements)
        self.query_uri = 'test_uri'
        execution_settings = QueryExecutionSettings(
            ExecutionPlanOptions(), ResultSetStorageType.FILE_STORAGE)
        self.query = Query(self.query_uri, self.statement_str,
                           execution_settings, QueryEvents())

        # Mock cursor/connection pair that yields two (id, value) rows
        self.mock_query_results = [('Id1', 'Value1'), ('Id2', 'Value2')]
        self.cursor = MockCursor(self.mock_query_results)
        self.connection = MockPGServerConnection(cur=self.cursor)

        # Two text columns, 'Id' then 'Value', returned by the columns-info mock
        self.columns_info = []
        for column_name in ('Id', 'Value'):
            column = DbColumn()
            column.data_type = 'text'
            column.column_name = column_name
            column.provider = PG_PROVIDER_NAME
            self.columns_info.append(column)
        self.get_columns_info_mock = mock.Mock(return_value=self.columns_info)
예제 #6
0
    def test_init(self):
        """A new Server should expose the connection's host, port, db name and version, plus child collections"""
        # If: I construct a new server object on top of a mocked connection
        host, port, dbname = 'host', '1234', 'dbname'
        connection = MockPGServerConnection(None, name=dbname, host=host, port=port)
        server = Server(connection)

        # Then:
        # ... The connection should be stored and exposed unchanged
        self.assertIsInstance(server._conn, MockPGServerConnection)
        self.assertIsInstance(server.connection, MockPGServerConnection)
        self.assertIs(server.connection, connection)

        # ... Each private field and its public accessor should carry the connection's value
        for attribute, expected in (('host', host),
                                    ('port', port),
                                    ('maintenance_db_name', dbname)):
            self.assertEqual(getattr(server, '_' + attribute), expected)
            self.assertEqual(getattr(server, attribute), expected)
        self.assertTupleEqual(server.version, server._conn.server_version)

        # ... Recovery options should be a lazily loaded thing
        self.assertIsInstance(server._recovery_props, NodeLazyPropertyCollection)

        for child_type, child_collection in server._child_objects.items():
            # ... Every child object entry is a NodeCollection
            self.assertIsInstance(child_collection, NodeCollection)

            # ... and is reachable through the pluralized, lower-cased property name
            property_name = inflection.pluralize(child_type.lower())
            self.assertIs(getattr(server, property_name), child_collection)
예제 #7
0
    def test_handle_get_database_info_request(self):
        """The database info handler should respond with the database name, owner and size"""
        owner_uri = 'test_uri'
        database = 'test_db'
        owner = 'test_user'

        # Request parameters and a context that records what the handler sends
        request_params = GetDatabaseInfoParameters()
        request_params.owner_uri = owner_uri
        context = MockRequestContext()

        # The connection service hands out a mock connection whose cursor yields the owner row
        owner_cursor = MockCursor([(owner, )])
        fake_connection = MockPGServerConnection(owner_cursor, name=database)
        self.connection_service.get_connection = mock.Mock(
            return_value=fake_connection)

        # If I send a get_database_info request
        self.admin_service._handle_get_database_info_request(
            context, request_params)

        # Then the service responded with the expected information
        sent = context.last_response_params
        self.assertIsInstance(sent, GetDatabaseInfoResponse)
        self.assertEqual(sent.database_info.options,
                         {'dbname': database, 'owner': owner, 'size': None})

        # And exactly one query ran: the owner lookup with the database name substituted in
        expected_query = "SELECT pg_catalog.pg_get_userbyid(db.datdba) FROM pg_catalog.pg_database db WHERE db.datname = '{}'".format(
            database)
        owner_cursor.execute.assert_called_once_with(expected_query)
    def test_handle_create_session_successful(self):
        """Creating a session over a working connection should respond, then send a success notification"""
        # Setup:
        # ... A connection service stub that reports a completed connect and returns a mock connection
        fake_connection = MockPGServerConnection(cur=None,
                                                 host='myserver',
                                                 name='postgres',
                                                 user='******',
                                                 port=123)
        connection_service = ConnectionService()
        connection_service.connect = mock.MagicMock(
            return_value=ConnectionCompleteParams())
        connection_service.get_connection = mock.MagicMock(
            return_value=fake_connection)

        # ... An object explorer service wired to that connection service
        oe = ObjectExplorerService()
        providers = {constants.CONNECTION_SERVICE_NAME: connection_service}
        oe._service_provider = utils.get_mock_service_provider(providers)
        oe._provider = constants.PG_PROVIDER_NAME
        oe._server = Server

        # ... Request parameters and the session URI they map to
        params, session_uri = _connection_details()

        # ... Validator for the session-created notification
        def validate_success_notification(response: SessionCreatedParameters):
            self.assertTrue(response.success)
            self.assertEqual(response.session_id, session_uri)
            self.assertIsNone(response.error_message)

            root = response.root_node
            self.assertIsInstance(root, NodeInfo)
            self.assertEqual(root.label, TEST_DBNAME)
            self.assertEqual(root.node_path, session_uri)
            self.assertEqual(root.node_type, 'Database')
            self.assertIsInstance(root.metadata, ObjectMetadata)

            # The root node's metadata is compared against the session's server object
            session_server = oe._session_map[session_uri].server
            self.assertEqual(root.metadata.urn, session_server.urn_base)
            self.assertEqual(root.metadata.name,
                             session_server.maintenance_db_name)
            self.assertEqual(root.metadata.metadata_type_name, 'Database')
            self.assertFalse(root.is_leaf)

        # ... Expected flow: a create-session response followed by the notification
        flow = RequestFlowValidator()
        flow.add_expected_response(
            CreateSessionResponse,
            lambda param: self.assertEqual(param.session_id, session_uri))
        flow.add_expected_notification(SessionCreatedParameters,
                                       SESSION_CREATED_METHOD,
                                       validate_success_notification)

        # If: I create a session and wait for its init task to finish
        oe._handle_create_session_request(flow.request_context, params)
        oe._session_map[session_uri].init_task.join()

        # Then:
        # ... The response and the success notification should have been sent as expected
        flow.validate()

        # ... The session should still exist and should have connection and server setup
        self.assertIn(session_uri, oe._session_map)
        created_session = oe._session_map[session_uri]
        self.assertIsInstance(created_session.server, Server)
        self.assertTrue(created_session.is_ready)
예제 #9
0
    def test_handle_scriptas_successful_operation(self):
        """Every ScriptOperation should be answered with the mocked script and reach the scripter once"""
        # NOTE: There's no need to test all types here, the scripter tests should handle this

        # Setup:
        # ... A scripting service backed by a connection service stub
        fake_connection = MockPGServerConnection()
        connection_service = ConnectionService()
        connection_service.connect = mock.MagicMock(
            return_value=ConnectionCompleteParams())
        connection_service.get_connection = mock.MagicMock(
            return_value=fake_connection)
        service = ScriptingService()
        service._service_provider = utils.get_mock_service_provider(
            {CONNECTION_SERVICE_NAME: connection_service})

        # ... Validator for each script-as response
        def validate_response(response: ScriptAsResponse) -> None:
            self.assertEqual(response.owner_uri, TestScriptingService.MOCK_URI)
            self.assertEqual(response.script, TestScriptingService.MOCK_SCRIPT)

        # ... Patch the Scripter used by the service so it yields a scripter with a mocked script()
        with mock.patch('ossdbtoolsservice.scripting.scripting_service.Scripter') as scripter_patch:
            stub_scripter = Scripter(fake_connection)
            stub_scripter.script = mock.MagicMock(
                return_value=TestScriptingService.MOCK_SCRIPT)
            scripter_patch.return_value = stub_scripter

            target_object = {
                'type': 'Table',
                'name': 'test_table',
                'schema': 'test_schema'
            }

            # For each operation supported
            for operation in ScriptOperation:
                # If: I request to script
                flow = RequestFlowValidator()
                flow.add_expected_response(ScriptAsResponse, validate_response)
                request = ScriptAsParameters.from_dict({
                    'ownerUri': TestScriptingService.MOCK_URI,
                    'operation': operation,
                    'scripting_objects': [target_object]
                })
                service._handle_scriptas_request(flow.request_context, request)

                # Then:
                # ... The request should have been handled correctly
                flow.validate()

            # ... Tally script() calls by their first positional argument (the operation)
            call_counts = dict.fromkeys(ScriptOperation, 0)
            for recorded_call in stub_scripter.script.call_args_list:
                call_counts[recorded_call[0][0]] += 1

            # ... Every operation should have been scripted exactly once
            for count in call_counts.values():
                self.assertEqual(count, 1)
예제 #10
0
 def setUp(self):
     """Create the mock cursor/connection and the batch inputs shared by the tests"""
     # Batch identity and text
     self._batch_id = 1
     self._batch_text = 'Select * from t1'

     # Cursor with no canned results, wrapped in a mock server connection
     self._cursor = utils.MockCursor(None)
     self._connection = MockPGServerConnection(cur=self._cursor)

     # Collaborators handed to the batch under test
     self._batch_events = BatchEvents()
     self._selection_data = SelectionData()
     self._result_set = mock.MagicMock()
예제 #11
0
    def test_get_obj_by_urn_wrong_server(self):
        """Looking up a URN that is invalid for this server should raise ValueError"""
        # Setup: Create a server object and a URN that is invalid for it
        target_server = Server(MockPGServerConnection())
        foreign_urn = '//[email protected]:456/Database.123/'

        # If: I get an object by that URN
        # Then: I should get an exception
        with self.assertRaises(ValueError):
            target_server.get_object_by_urn(foreign_urn)
예제 #12
0
    def test_get_obj_by_urn_wrong_collection(self):
        """Looking up a URN whose first path segment is not a valid server collection should raise ValueError"""
        # Setup: Create a server object
        target_server = Server(MockPGServerConnection())

        # If: I get an object by a URN that points to an invalid path off the server
        # Then: I should get an exception
        with self.assertRaises(ValueError):
            # The URN is built inside the context, so a ValueError raised here also counts
            bad_path_urn = parse.urljoin(target_server.urn_base,
                                         'Datatype.123/')
            target_server.get_object_by_urn(bad_path_urn)
예제 #13
0
 def setUp(self):
     """Create the mock connection and the Scripter built on it that the tests share."""
     connection_kwargs = {
         'cur': None,
         'port': "8080",
         'host': "test",
         'name': "test",
     }
     self.conn = MockPGServerConnection(**connection_kwargs)
     self.script = scripter.Scripter(self.conn)
예제 #14
0
    def test_get_obj_by_urn_empty(self):
        """None, empty and whitespace-only URNs should each raise ValueError"""
        # Setup: Create a server object
        target_server = Server(MockPGServerConnection())

        # If: I get an object by its URN without providing a usable URN
        # Then: I should get an exception for every such input
        for blank_urn in (None, '', '\t \n\r'):
            with self.assertRaises(ValueError):
                target_server.get_object_by_urn(blank_urn)
예제 #15
0
    def test_init_not_connected(self):
        """A Database whose name differs from the server connection's database should still initialize"""
        # If: I create a DB whose name does not match the connection's database name
        db_name = 'dbname'
        server = Server(MockPGServerConnection(None, name='not_connected'))
        database = Database(server, db_name)

        # Then:
        # ... Default validation should pass
        self._init_validation(database, server, None, db_name)

        # ... Both the private field and the public accessor for schemas are non-None
        # NOTE(review): the test name suggests schemas might be unavailable when not
        # connected, yet these assertions require non-None values - confirm the intent
        self.assertIsNotNone(database._schemas)
        self.assertIsNotNone(database.schemas)
예제 #16
0
    def test_init_connected(self):
        """A Database named like the server connection's database should expose a schema NodeCollection"""
        # If: I create a DB whose name matches the connection's database name
        db_name = 'dbname'
        server = Server(MockPGServerConnection(None, name=db_name))
        database = Database(server, db_name)

        # Then:
        # ... Default validation should pass
        self._init_validation(database, server, None, db_name)

        # ... The schema node collection should be defined and exposed via the property
        self.assertIsInstance(database._schemas, NodeCollection)
        self.assertIs(database.schemas, database._schemas)
예제 #17
0
    def test_get_obj_by_urn_success(self):
        """A URN naming a database registered under the server should resolve to that database"""
        # Setup: Create a server with a database (oid 123) registered under it
        target_server = Server(MockPGServerConnection())
        registered_db = Database(target_server, 'test_db')
        registered_db._oid = 123
        target_server._child_objects[Database.__name__] = {123: registered_db}

        # If: I get an object by its URN
        lookup_urn = parse.urljoin(target_server.urn_base, '/Database.123/')
        resolved = target_server.get_object_by_urn(lookup_urn)

        # Then: The object I get back should be the very object I registered
        self.assertIs(resolved, registered_db)
예제 #18
0
    def test_urn_base(self):
        """The server's URN base should have the form //user@host:port using the connection's values"""
        # Setup:
        # ... Create a server object that has a connection
        target_server = Server(MockPGServerConnection())

        # If: I get the URN base for the server
        actual_urn_base = target_server.urn_base

        # Then: The urn base should match the expected pattern
        matched = re.match(r'//(?P<user>.+)@(?P<host>.+):(?P<port>\d+)', actual_urn_base)
        self.assertIsNotNone(matched)

        # ... and each captured part should equal the corresponding server value
        self.assertEqual(matched.group('user'), target_server.connection.user_name)
        self.assertEqual(matched.group('host'), target_server.host)
        self.assertEqual(matched.group('port'), target_server.port)
 def setUp(self):
     """Create the metadata factory, a Server over a mock connection, and the names the tests use"""
     self._smo_metadata_factory = SmoEditTableMetadataFactory()

     # Server built on a mocked connection
     connection_kwargs = {
         'cur': None,
         'port': "8080",
         'host': "test",
         'name': "test",
         'user': "******",
     }
     self._connection = MockPGServerConnection(**connection_kwargs)
     self._server = Server(self._connection)

     # Object names and type labels used by the tests
     self._schema_name = 'public'
     self._table_name = 'Employee'
     self._view_name = 'Vendor'
     self._table_object_type = 'TABLE'
     self._view_object_type = 'VIEW'

     # A single column attached to the server
     only_column = Column(self._server, "testTable", 'testName', 'testDatatype')
     self._columns = [only_column]
예제 #20
0
    def test_changing_options_disconnects_existing_connection(self):
        """
        Connecting again on the same owner URI with different options should close the
        connection that was already open
        """
        # Arrange: mock data
        owner_uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockPGServerConnection(cur=None,
                                                     host='myserver',
                                                     name='postgres',
                                                     user='******')

        # Register a ConnectionInfo (options include 'abc': 123) holding that connection
        original_options = {
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123
        }
        existing_info = ConnectionInfo(
            owner_uri, ConnectionDetails.from_data(original_options))
        existing_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = existing_info

        # Build a request for the same owner URI where only the 'abc' option differs
        changed_options = {
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 234
        }
        request = ConnectRequestParams.from_dict({
            'ownerUri': owner_uri,
            'type': conn_type,
            'connection': {
                'options': changed_options
            }
        })

        # Act: connect with the changed options while psycopg2.connect is patched
        connect_stub = mock.Mock(return_value=existing_connection)
        with mock.patch('psycopg2.connect', new=connect_stub):
            self.connection_service.connect(request)

        # Assert: the pre-existing connection was closed exactly once
        existing_connection.close.assert_called_once()
    def test_create_connection_successful(self):
        """_create_connection should return the connection provided by the connection service"""
        # Setup: an OE service wired to a connection service stub
        fake_connection = MockPGServerConnection()
        oe = ObjectExplorerService()
        connection_service = ConnectionService()
        connection_service.connect = mock.MagicMock(
            return_value=ConnectionCompleteParams())
        connection_service.get_connection = mock.MagicMock(
            return_value=fake_connection)
        providers = {constants.CONNECTION_SERVICE_NAME: connection_service}
        oe._service_provider = utils.get_mock_service_provider(providers)

        # ... and a session built from the shared connection details
        params, session_uri = _connection_details()
        oe_session = ObjectExplorerSession(session_uri, params)

        # If: I create a connection for a database within that session
        created = oe._create_connection(oe_session, 'foo_database')

        # Then: the stubbed connection is returned after one connect and one get_connection call
        self.assertIsNotNone(created)
        self.assertEqual(created, fake_connection)
        connection_service.connect.assert_called_once()
        connection_service.get_connection.assert_called_once()
예제 #22
0
    def test_metadata_list_request(self):
        """Test that the metadata list handler properly starts a thread to list metadata and responds with the list"""
        # Set up the parameters and mocks for the request
        expected_metadata = [
            ObjectMetadata(schema='schema1', name='table1', metadata_type=MetadataType.TABLE),
            ObjectMetadata(schema='schema1', name='view1', metadata_type=MetadataType.VIEW),
            ObjectMetadata(schema='schema1', name='function1', metadata_type=MetadataType.FUNCTION),
            ObjectMetadata(schema='schema1', name='table2', metadata_type=MetadataType.TABLE),
            ObjectMetadata(schema='schema2', name='view1', metadata_type=MetadataType.VIEW),
            ObjectMetadata(schema='schema2', name='function1', metadata_type=MetadataType.FUNCTION),
        ]

        # Single-letter type codes used in the mocked query rows
        metadata_type_to_str_map = {
            MetadataType.TABLE: 't',
            MetadataType.VIEW: 'v',
            MetadataType.FUNCTION: 'f'
        }

        # Query results have schema_name, object_name, and object_type columns in that order
        list_query_result = [
            (metadata.schema, metadata.name, metadata_type_to_str_map[metadata.metadata_type])
            for metadata in expected_metadata
        ]
        mock_cursor = MockCursor(list_query_result)
        mock_connection = MockPGServerConnection(cur=mock_cursor)
        self.connection_service.get_connection = mock.Mock(return_value=mock_connection)
        request_context = MockRequestContext()
        params = MetadataListParameters()
        params.owner_uri = self.test_uri
        mock_thread = MockThread()
        with mock.patch('threading.Thread', new=mock.Mock(side_effect=mock_thread.initialize_target)):
            # If I call the metadata list request handler
            self.metadata_service._handle_metadata_list_request(request_context, params)
            # Then the worker thread was kicked off
            self.assertEqual(mock_thread.target, self.metadata_service._metadata_list_worker)
            mock_thread.start.assert_called_once()

        # And the worker retrieved the correct connection and executed a query on it
        self.connection_service.get_connection.assert_called_once_with(self.test_uri, ConnectionType.DEFAULT)
        mock_cursor.execute.assert_called_once()

        # And the handler responded with the expected results
        self.assertIsNone(request_context.last_error_message)
        self.assertIsNone(request_context.last_notification_method)
        response = request_context.last_response_params
        self.assertIsInstance(response, MetadataListResponse)

        # Check the count first: without it, an empty or truncated response would skip
        # the per-item checks below and the test would pass vacuously
        self.assertEqual(len(response.metadata), len(expected_metadata))
        for actual_metadata, expected in zip(response.metadata, expected_metadata):
            self.assertIsInstance(actual_metadata, ObjectMetadata)
            self.assertEqual(actual_metadata.schema, expected.schema)
            self.assertEqual(actual_metadata.name, expected.name)
            self.assertEqual(actual_metadata.metadata_type, expected.metadata_type)
예제 #23
0
    def test_same_options_uses_existing_connection(self):
        """Connecting again with identical options should reuse the existing connection"""
        # Arrange: mock data
        owner_uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        psycopg_connection = MockPsycopgConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })
        existing_connection = MockPGServerConnection(cur=None,
                                                     host='myserver',
                                                     name='postgres',
                                                     user='******')

        # Register a ConnectionInfo holding the existing server connection
        existing_details = ConnectionDetails.from_data({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123
        })
        existing_info = ConnectionInfo(owner_uri, existing_details)
        existing_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = existing_info

        # Build a request that reuses the very same options
        request = ConnectRequestParams.from_dict({
            'ownerUri': owner_uri,
            'type': conn_type,
            'connection': {
                'options': existing_details.options
            }
        })

        # Act: connect while psycopg2.connect is patched; it must never be invoked
        connect_patch = mock.patch(
            'psycopg2.connect',
            new=mock.Mock(return_value=psycopg_connection))
        with connect_patch as connect_stub:
            result = self.connection_service.connect(request)
            connect_stub.assert_not_called()

        # Assert: nothing was closed and the response carries a connection id
        psycopg_connection.close.assert_not_called()
        self.assertIsNotNone(result.connection_id)
예제 #24
0
    def test_maintenance_db(self):
        """maintenance_db should be looked up in the database collection by the connection's db name"""
        # Setup:
        # ... Create a server object whose connection targets 'dbname'
        target_server = Server(MockPGServerConnection(None, name='dbname'))

        # ... Swap the database collection for a mock whose indexer returns a sentinel
        sentinel_db = {}
        database_collection = mock.Mock()
        database_collection.__getitem__ = mock.MagicMock(return_value=sentinel_db)
        target_server._child_objects[Database.__name__] = database_collection

        # If: I retrieve the maintenance db for the server
        found = target_server.maintenance_db

        # Then:
        # ... It must be the sentinel, fetched with a single lookup by 'dbname'
        self.assertIs(found, sentinel_db)
        indexer = target_server._child_objects[Database.__name__].__getitem__
        indexer.assert_called_once_with('dbname')
예제 #25
0
    def test_refresh(self):
        """Refreshing a server should reset its databases, roles, tablespaces and recovery properties"""
        # Setup:
        # ... Create a server object that has a connection
        target_server = Server(MockPGServerConnection())

        # ... Mock out the reset method on each collection the refresh is expected to touch
        resettable = ('databases', 'roles', 'tablespaces', '_recovery_props')
        for attribute in resettable:
            getattr(target_server, attribute).reset = mock.MagicMock()

        # If: I refresh the server
        target_server.refresh()

        # Then: Every one of those collections should have been reset exactly once
        for attribute in resettable:
            getattr(target_server, attribute).reset.assert_called_once()
    def setUp(self):
        """Create an OE service with one registered session, a server, a database and a stubbed connection service"""
        # Services under test; the stubbed connection is just an empty dict
        self.cs = ConnectionService()
        self.mock_connection = {}
        self.oe = ObjectExplorerService()

        # Register a session for the shared connection details
        params, session_uri = _connection_details()
        self.session = ObjectExplorerSession(session_uri, params)
        self.oe._session_map[session_uri] = self.session

        # Attach a server with a single database that shares the server's connection
        db_name = 'dbname'
        self.mock_server = Server(MockPGServerConnection())
        self.session.server = self.mock_server
        self.db = Database(self.mock_server, db_name)
        self.db._connection = self.mock_server._conn
        self.session.server._child_objects[Database.__name__] = [self.db]

        # Stub the connection service calls and expose it through a mock service provider
        self.cs.get_connection = mock.MagicMock(
            return_value=self.mock_connection)
        self.cs.disconnect = mock.MagicMock(return_value=True)
        providers = {constants.CONNECTION_SERVICE_NAME: self.cs}
        self.oe._service_provider = utils.get_mock_service_provider(providers)
예제 #27
0
    def test_disconnect_for_invalid_connection(self):
        """Closing a connection type that was never opened should return a falsy result and close nothing"""
        # Arrange: an owner URI that only has a DEFAULT connection open
        owner_uri = 'someuri'
        open_type = ConnectionType.DEFAULT
        open_connection = MockPGServerConnection(cur=None,
                                                 host='myserver1',
                                                 name='postgres1',
                                                 user='******')

        # Register the ConnectionInfo in the connection service's map
        info = ConnectionInfo(owner_uri,
                              ConnectionDetails.from_data({'abc': 123}))
        info.add_connection(open_type, open_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Act: ask to close the EDIT connection, which does not exist
        result = self.connection_service._close_connections(
            info, ConnectionType.EDIT)

        # Assert: the DEFAULT connection is left open and the call reports failure
        open_connection.close.assert_not_called()
        self.assertFalse(result)