def test_generate_uri_missing_params(self):
        """Generating a session URI must fail whenever any required param is None."""
        # Setup: A complete set of connection params; each case below blanks
        # out exactly one of them, so adding a key here adds a case automatically
        complete_params = {
            'host': TEST_HOST,
            'dbname': TEST_DBNAME,
            'user': TEST_USER,
            'port': TEST_PORT
        }

        for missing_param in complete_params:
            param_set = ConnectionDetails.from_data(
                {**complete_params, missing_param: None})

            # subTest reports which param was missing when a case fails, and
            # keeps checking the remaining cases
            with self.subTest(missing_param=missing_param):
                # If: I generate a session URI from params that are missing a value
                # Then: I should get an exception
                with self.assertRaises(Exception):
                    ObjectExplorerService._generate_session_uri(
                        param_set, constants.PG_PROVIDER_NAME)
Exemple #2
0
    def server_info_is_cloud_internal(self, host_suffix, is_cloud):
        """
        Connect to a server whose host is 'myserver' + host_suffix and check
        that the connect response's server_info.is_cloud equals is_cloud.
        """
        server_host = f'myserver{host_suffix}'

        # Connection request options for the target server
        owner_uri = 'someuri'
        details = ConnectionDetails()
        details.options = {
            'user': '******',
            'password': '******',
            'host': server_host,
            'dbname': 'postgres'
        }

        # Fake psycopg2 connection reporting the same host back
        fake_connection = MockPsycopgConnection(
            dsn_parameters={
                'host': server_host,
                'dbname': 'postgres',
                'user': '******'
            })
        fake_connect = mock.Mock(return_value=fake_connection)

        with mock.patch('psycopg2.connect', new=fake_connect):
            request = ConnectRequestParams(details, owner_uri,
                                           ConnectionType.DEFAULT)
            response = self.connection_service.connect(request)

        # The response carries connection/server info and the expected cloud flag
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertEqual(response.server_info.is_cloud, is_cloud)
def _connection_details() -> Tuple[ConnectionDetails, str]:
    """
    Build ConnectionDetails from the TEST_* host/dbname/user/port constants and
    generate the PG object explorer session URI for them.

    :returns: tuple of (connection details, session URI)
    """
    details = ConnectionDetails()
    details.options = dict(
        host=TEST_HOST,
        dbname=TEST_DBNAME,
        user=TEST_USER,
        port=TEST_PORT,
    )
    uri = ObjectExplorerService._generate_session_uri(
        details, constants.PG_PROVIDER_NAME)
    return details, uri
Exemple #4
0
    def test_get_connection_for_existing_connection(self):
        """
        Test that get_connection opens a connection via psycopg2 when a
        ConnectionInfo is already registered for the URI but holds no
        connection of the requested type
        """
        # Set up the test with mock data
        connection_uri = 'someuri'
        connection_type = ConnectionType.EDIT
        mock_connection = MockPsycopgConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        # Insert a ConnectionInfo object with no connections added into the
        # connection service's map
        connection_details = ConnectionDetails.from_data({})
        connection_info = ConnectionInfo(connection_uri, connection_details)
        self.connection_service.owner_to_connection_map[
            connection_uri] = connection_info

        # Get the connection without first creating it; psycopg2.connect is
        # expected to be called exactly once to create it on demand
        with mock.patch(
                'psycopg2.connect', new=mock.Mock(
                    return_value=mock_connection)) as mock_psycopg2_connect:
            connection = self.connection_service.get_connection(
                connection_uri, connection_type)
            mock_psycopg2_connect.assert_called_once()

        # The returned connection should wrap the mocked psycopg2 connection
        self.assertEqual(connection._conn, mock_connection)
    def setUp(self):
        """Build the service provider, connection fixtures and refresher mocks used by each test"""
        self.default_connection_key = 'server_db_user'

        # Service provider wired to a fresh connection service
        self.mock_connection_service = ConnectionService()
        self.mock_server = JSONRPCServer(None, None)
        provider = ServiceProvider(self.mock_server, {}, PG_PROVIDER_NAME, None)
        provider._services[constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
        provider._is_initialized = True
        self.mock_service_provider = provider

        # Connection details plus the context key / URI expected to be derived from them
        details = ConnectionDetails.from_data({})
        details.server_name = 'test_host'
        details.database_name = 'test_db'
        details.user_name = 'user'
        self.connection_details = details
        self.expected_context_key = '|'.join(('test_host', 'test_db', 'user'))
        self.expected_connection_uri = INTELLISENSE_URI + self.expected_context_key
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri, details)

        # Mock the CompletionRefresher so tests don't spin up a separate refresh thread
        self.refresh_method_mock = mock.MagicMock()
        self.refresher_mock = mock.MagicMock(refresh=self.refresh_method_mock)
Exemple #6
0
    def test_build_connection_response(self):
        """_build_connection_response should describe the connection it is built from"""
        # Setup: a ConnectionInfo wrapping a single mocked EDIT connection
        server_name, db_name, user = 'testserver', 'testdb', '******'
        server_connection = MockPGServerConnection(cur=None, host=server_name,
                                                   name=db_name, user=user)
        conn_type = ConnectionType.EDIT
        details = ConnectionDetails.from_data(opts={})
        owner_uri = 'test_uri'
        info = ConnectionInfo(owner_uri, details)
        info._connection_map = {conn_type: server_connection}

        # If: I build a connection response for that connection
        response = ossdbtoolsservice.connection.connection_service._build_connection_response(
            info, conn_type)

        # Then: the response reflects the mocked connection
        expected_version = '.'.join(
            str(server_connection.server_version[idx]) for idx in range(3))
        self.assertEqual(response.owner_uri, owner_uri)
        self.assertEqual(response.server_info.server_version, expected_version)
        self.assertEqual(response.server_info.is_cloud, False)
        self.assertEqual(response.connection_summary.server_name, server_name)
        self.assertEqual(response.connection_summary.database_name, db_name)
        self.assertEqual(response.connection_summary.user_name, user)
        self.assertEqual(response.type, conn_type)
Exemple #7
0
    def test_list_databases_handles_query_failure(self):
        """Test that the list databases handler returns an error if the list databases query fails for any reason"""
        # Set up the test with mock data: a cursor whose fetchall raises
        mock_query_results = [('database1', ), ('database2', )]
        connection_uri = 'someuri'
        mock_cursor = MockCursor(mock_query_results)
        mock_cursor.fetchall.side_effect = psycopg2.ProgrammingError('')
        mock_connection = MockPGServerConnection(cur=mock_cursor,
                                                 host='myserver',
                                                 name='postgres',
                                                 user='******')
        mock_request_context = utils.MockRequestContext()

        # Insert a ConnectionInfo object into the connection service's map
        connection_details = ConnectionDetails.from_data({})
        connection_info = ConnectionInfo(connection_uri, connection_details)
        self.connection_service.owner_to_connection_map[
            connection_uri] = connection_info

        # Call the listdatabases handler; the failing query should produce an
        # error rather than a response
        params = ListDatabasesParams()
        params.owner_uri = connection_uri

        with mock.patch('psycopg2.connect',
                        new=mock.Mock(return_value=mock_connection)):
            self.connection_service.handle_list_databases(
                mock_request_context, params)

        # Only an error was sent: no notification and no response
        self.assertIsNone(mock_request_context.last_notification_method)
        self.assertIsNone(mock_request_context.last_notification_params)
        self.assertIsNone(mock_request_context.last_response_params)
        self.assertIsNotNone(mock_request_context.last_error_message)
    def test_routing_target_get_nodes_not_empty(self):
        """RoutingTarget.get_nodes yields its folder nodes first, then the generator's nodes"""
        # Setup: a node generator producing two nodes, and two folders
        first_node, second_node = NodeInfo(), NodeInfo()
        node_generator = mock.MagicMock(return_value=[first_node, second_node])
        folders = [
            session.Folder('Folder1', 'fp1'),
            session.Folder('Folder2', 'fp2')
        ]

        # If: I ask the routing target for its nodes
        target = session.RoutingTarget(folders, node_generator)
        path = '/'
        params = {}
        oe_session = ObjectExplorerSession('session_id', ConnectionDetails())
        nodes = target.get_nodes(False, path, oe_session, params)

        # Then: I get a list of 4 NodeInfo objects, folders before generated nodes
        self.assertIsInstance(nodes, list)
        for node in nodes:
            self.assertIsInstance(node, NodeInfo)
        self.assertEqual(len(nodes), 4)
        for folder_node in nodes[:2]:
            self.assertEqual(folder_node.node_type, 'Folder')
        self.assertIs(nodes[2], first_node)
        self.assertIs(nodes[3], second_node)

        # ... and the generator received the same routing arguments
        node_generator.assert_called_once_with(False, path, oe_session, params)
Exemple #9
0
    def test_list_databases(self):
        """handle_list_databases responds with the database names returned by the query"""
        # Setup: a psycopg2 connection whose cursor yields two database rows
        rows = [('database1', ), ('database2', )]
        owner_uri = 'someuri'
        fake_psycopg_connection = MockPsycopgConnection(
            dsn_parameters={
                'host': 'myserver',
                'dbname': 'postgres',
                'user': '******'
            },
            cursor=MockCursor(rows))
        request_context = utils.MockRequestContext()

        # ... and a ConnectionInfo registered for the owner URI
        details = ConnectionDetails.from_data({})
        self.connection_service.owner_to_connection_map[owner_uri] = ConnectionInfo(
            owner_uri, details)

        # If: I call the list databases handler for that URI
        params = ListDatabasesParams()
        params.owner_uri = owner_uri
        fake_connect = mock.Mock(return_value=fake_psycopg_connection)
        with mock.patch('psycopg2.connect', new=fake_connect):
            self.connection_service.handle_list_databases(request_context, params)

        # Then: the response lists the first column of each row
        expected_names = [row[0] for row in rows]
        self.assertEqual(request_context.last_response_params.database_names,
                         expected_names)
Exemple #10
0
    def test_get_connection_creates_connection(self):
        """
        Test that after connect() creates a connection for a URI and type,
        get_connection returns that connection without creating another one
        """
        owner_uri = 'someuri'
        conn_type = ConnectionType.EDIT
        server_connection = MockPGServerConnection(cur=None,
                                                   host='myserver',
                                                   name='postgres',
                                                   user='******')

        # Register an empty ConnectionInfo for the owner URI
        details = ConnectionDetails.from_data({})
        self.connection_service.owner_to_connection_map[owner_uri] = ConnectionInfo(
            owner_uri, details)

        create_path = 'ossdbtoolsservice.driver.connection_manager.ConnectionManager._create_connection'
        fake_create = mock.Mock(return_value=server_connection)
        with mock.patch(create_path, new=fake_create) as mock_create_connection:
            # If: I open the connection and then ask for it
            self.connection_service.connect(
                ConnectRequestParams(details, owner_uri, conn_type))
            fetched = self.connection_service.get_connection(owner_uri, conn_type)

            # Then: the created connection is returned and was created only once
            self.assertEqual(fetched, server_connection)
            mock_create_connection.assert_called_once()
Exemple #11
0
    def test_disconnect_all_types(self):
        """Test that _close_connections closes every open connection type when no type is given"""
        owner_uri = 'someuri'
        first_type, second_type = ConnectionType.DEFAULT, ConnectionType.EDIT
        first_conn = MockPGServerConnection(cur=None,
                                            host='myserver1',
                                            name='postgres1',
                                            user='******')
        second_conn = MockPGServerConnection(cur=None,
                                             host='myserver2',
                                             name='postgres2',
                                             user='******')

        # Register a ConnectionInfo holding both connections
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({'abc': 123}))
        for conn_type, conn in ((first_type, first_conn), (second_type, second_conn)):
            info.add_connection(conn_type, conn)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Close without naming a type: both connections get closed
        closed = self.connection_service._close_connections(info)
        for conn in (first_conn, second_conn):
            conn.close.assert_called_once()
        self.assertTrue(closed)
    def test_routing_target_get_nodes_empty(self):
        """A RoutingTarget with no folders and no node generator yields no nodes"""
        # If: I ask an empty routing target for its nodes
        empty_target = session.RoutingTarget(None, None)
        oe_session = ObjectExplorerSession('session_id', ConnectionDetails())
        nodes = empty_target.get_nodes(False, '/', oe_session, {})

        # Then: nothing comes back
        self.assertListEqual(nodes, [])
    def test_routing_invalid_path(self):
        """Routing a path that matches no route raises ValueError"""
        # If: I ask to route a path with no matching route
        # Then: a ValueError is raised
        with self.assertRaises(ValueError):
            oe_session = ObjectExplorerSession('session_id', ConnectionDetails())
            self.object_explorer_service._route_request(False, oe_session, '!/invalid!/')
    def test_routing_match(self):
        """Routing the root path '/' returns a list of NodeInfo objects"""
        # If: I route a valid request for the root path
        oe_session = ObjectExplorerSession('session_id', ConnectionDetails())
        nodes = self.object_explorer_service._route_request(False, oe_session, '/')

        # Then: the output is a list containing only nodes
        self.assertIsInstance(nodes, list)
        for node in nodes:
            self.assertIsInstance(node, NodeInfo)
Exemple #15
0
    def run_on_connect_callback(self, conn_type: ConnectionType,
                                expect_callback: bool) -> None:
        """
        Connect with the given connection type and check whether the registered
        on-connect callbacks fire.

        :param conn_type: connection type used both to connect and to fetch the connection
        :param expect_callback: True if each callback should be called exactly once
            with an object whose owner_uri is the connected URI; False if no
            callback should be called
        """
        callbacks = [MagicMock(), MagicMock()]
        for callback in callbacks:
            self.connection_service.register_on_connect_callback(callback)

        # Set up the parameters for the connection
        connection_uri = 'someuri'
        connection_details = ConnectionDetails()
        connection_details.options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': 'postgres'
        }

        # Set up the mock connection for psycopg2's connect method to return
        mock_psycopg_connection = MockPsycopgConnection(
            dsn_parameters={
                'host': 'myserver',
                'dbname': 'postgres',
                'user': '******'
            })

        # Call the connection service's connect method with the supported options
        with mock.patch('psycopg2.connect',
                        new=mock.Mock(return_value=mock_psycopg_connection)):
            self.connection_service.connect(
                ConnectRequestParams(connection_details, connection_uri,
                                     conn_type))
        self.connection_service.get_connection(connection_uri, conn_type)

        # Then: the callbacks were (or were not) called, as expected
        for callback in callbacks:
            if expect_callback:
                callback.assert_called_once()
                # The first positional argument is expected to be the ConnectionInfo
                callargs: ConnectionInfo = callback.call_args[0][0]
                self.assertEqual(callargs.owner_uri, connection_uri)
            else:
                callback.assert_not_called()
    def test_on_connect_sends_notification(self):
        """
        Test that the service sends an intellisense ready notification after handling an on connect notification
        from the connection service. This is a slightly more end-to-end test that verifies calling through to the
        queue layer
        """
        # Upper bound on waiting for the on_connect task, so a stuck worker
        # fails the test instead of hanging the whole run
        join_timeout_seconds = 10

        # If: I create a new language service
        service: LanguageService = self._init_service_with_flow_validator()
        conn_info = ConnectionInfo(
            'file://msuri.sql',
            ConnectionDetails.from_data({
                'host': None,
                'dbname': 'TEST_DBNAME',
                'user': '******'
            }))

        connect_result = mock.MagicMock()
        connect_result.error_message = None
        self.mock_connection_service.get_connection = mock.Mock(
            return_value=mock.MagicMock())
        self.mock_connection_service.connect = mock.MagicMock(
            return_value=connect_result)

        def validate_success_notification(response: IntelliSenseReadyParams):
            self.assertEqual(response.owner_uri, conn_info.owner_uri)

        # When: I notify of a connection complete for a given URI
        self.flow_validator.add_expected_notification(
            IntelliSenseReadyParams, INTELLISENSE_READY_NOTIFICATION,
            validate_success_notification)

        refresher_mock = mock.MagicMock()
        refresh_method_mock = mock.MagicMock()
        refresher_mock.refresh = refresh_method_mock
        patch_path = 'ossdbtoolsservice.language.operations_queue.CompletionRefresher'
        with mock.patch(patch_path) as refresher_patch:
            refresher_patch.return_value = refresher_mock
            task: threading.Thread = service.on_connect(conn_info)
            # And when refresh is "complete"
            refresh_method_mock.assert_called_once()
            callback = refresh_method_mock.call_args[0][0]
            self.assertIsNotNone(callback)
            callback(None)
            # Wait (bounded) for the task to return
            task.join(timeout=join_timeout_seconds)
            self.assertFalse(task.is_alive(),
                             'on_connect task did not finish in time')

        # Then:
        # an intellisense ready notification should be sent for that URI
        self.flow_validator.validate()
        # ... and the scriptparseinfo should be created
        info: ScriptParseInfo = service.get_script_parse_info(
            conn_info.owner_uri)
        self.assertIsNotNone(info)
        # ... and the info should have the connection key set
        self.assertEqual(info.connection_key,
                         OperationsQueue.create_key(conn_info))
    def test_handle_close_session_incomplete_params(self):
        """Closing an OE session with incomplete connection params produces an error response"""
        # NOTE: Only the generate-URI step needs to throw here; a separate test
        #       verifies that it throws for every missing param
        validator = RequestFlowValidator().add_expected_error(
            type(None), RequestFlowValidator.basic_error_validation)
        incomplete_params = ConnectionDetails.from_data({})

        # If: I close an OE session with missing params
        self.oe._handle_close_session_request(validator.request_context,
                                              incomplete_params)

        # Then: an error response was sent
        validator.validate()
Exemple #18
0
    def test_get_connection_info(self):
        """get_connection_info returns the exact ConnectionInfo registered for a URI"""
        owner_uri = 'someuri'

        # Register a ConnectionInfo directly in the service's map
        expected_info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({}))
        self.connection_service.owner_to_connection_map[owner_uri] = expected_info

        # Looking it up returns that same object
        self.assertIs(self.connection_service.get_connection_info(owner_uri),
                      expected_info)
    def handle_change_database_request(self, request_context: RequestContext,
                                       params: ChangeDatabaseRequestParams) -> bool:
        """
        Change the database of an existing connection by issuing a connect
        request for the same owner URI with the existing options and the new
        database name.

        :param request_context: context used to respond to the request
        :param params: owner URI of the existing connection and the new database name
        :returns: False (after responding False to the client) if no connection
            exists for the owner URI; True once the connect request has been
            passed to handle_connect_request
        """
        connection_info: ConnectionInfo = self.get_connection_info(params.owner_uri)

        if connection_info is None:
            # Respond explicitly; otherwise the client's request is never answered
            request_context.send_response(False)
            return False

        # Copy so the existing connection's options are not modified
        connection_info_params: Dict[str, str] = connection_info.details.options.copy()
        connection_info_params["dbname"] = params.new_database
        connection_details: ConnectionDetails = ConnectionDetails.from_data(connection_info_params)

        connection_request_params: ConnectRequestParams = ConnectRequestParams(
            connection_details, params.owner_uri, ConnectionType.DEFAULT)
        self.handle_connect_request(request_context, connection_request_params)
        return True
Exemple #20
0
    def test_changing_options_disconnects_existing_connection(self):
        """
        Test that connect() closes the existing connection for an owner URI when
        asked to connect the same URI with different options
        """
        owner_uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockPGServerConnection(cur=None,
                                                     host='myserver',
                                                     name='postgres',
                                                     user='******')

        # Options shared by the old connection and the new request; only 'abc' differs
        base_options = {'host': 'myserver', 'dbname': 'postgres', 'user': '******'}

        # Register the old connection under the owner URI
        old_info = ConnectionInfo(
            owner_uri, ConnectionDetails.from_data({**base_options, 'abc': 123}))
        old_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = old_info

        # A request for the same URI with a different 'abc' value
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': owner_uri,
            'type': conn_type,
            'connection': {
                'options': {**base_options, 'abc': 234}
            }
        })

        # Connecting with the changed options must close the old connection
        fake_connect = mock.Mock(return_value=existing_connection)
        with mock.patch('psycopg2.connect', new=fake_connect):
            self.connection_service.connect(params)
        existing_connection.close.assert_called_once()
Exemple #21
0
    def test_same_options_uses_existing_connection(self):
        """Test that the connect method uses an existing connection when connecting again with the same options"""
        # Set up the test with mock data
        connection_uri = 'someuri'
        connection_type = ConnectionType.DEFAULT
        mock_psycopg_connection = MockPsycopgConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })
        mock_server_connection = MockPGServerConnection(cur=None,
                                                        host='myserver',
                                                        name='postgres',
                                                        user='******')

        # Insert a ConnectionInfo object into the connection service's map
        old_connection_details = ConnectionDetails.from_data({
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******',
            'abc': 123
        })
        old_connection_info = ConnectionInfo(connection_uri,
                                             old_connection_details)
        old_connection_info.add_connection(connection_type,
                                           mock_server_connection)
        self.connection_service.owner_to_connection_map[
            connection_uri] = old_connection_info

        # Connect with identical options
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri':
            connection_uri,
            'type':
            connection_type,
            'connection': {
                'options': old_connection_details.options
            }
        })
        with mock.patch('psycopg2.connect',
                        new=mock.Mock(return_value=mock_psycopg_connection)
                        ) as mock_psycopg2_connect:
            response = self.connection_service.connect(params)
            mock_psycopg2_connect.assert_not_called()

        # The existing connection must be reused, not closed. The psycopg2 mock
        # is never handed to the service, so checking its close() alone would
        # not detect a disconnect.
        mock_server_connection.close.assert_not_called()
        mock_psycopg_connection.close.assert_not_called()
        self.assertIsNotNone(response.connection_id)
    def test_handle_create_session_incomplete_params(self):
        """Creating an OE session with incomplete params sends an error and stores no session"""
        # Setup: an OE service backed by a mock service provider
        oe = ObjectExplorerService()
        oe._service_provider = utils.get_mock_service_provider({})

        # NOTE: Only the generate-URI step needs to throw here; a separate test
        #       verifies that it throws for every missing param
        validator = RequestFlowValidator().add_expected_error(
            type(None), RequestFlowValidator.basic_error_validation)
        incomplete_params = ConnectionDetails.from_data({})

        # If: I create an OE session with missing params
        oe._handle_create_session_request(validator.request_context,
                                          incomplete_params)

        # Then: an error response was sent ...
        validator.validate()
        # ... and no session was registered
        self.assertDictEqual(oe._session_map, {})
    def _create_connection(self, session: ObjectExplorerSession,
                           database_name: str) -> Optional[ServerConnection]:
        """
        Connect to database_name using the session's connection options, with
        an owner URI made from the session ID followed by the database name.

        :raises RuntimeError: if the connect result carries an error message
        :returns: the connection the connection service holds for that owner URI
        """
        connection_service = self._service_provider[
            utils.constants.CONNECTION_SERVICE_NAME]

        # Same options as the session, pointed at the requested database
        options = session.connection_details.options.copy()
        options['dbname'] = database_name
        owner_uri = session.id + database_name
        request = ConnectRequestParams(ConnectionDetails.from_data(options),
                                       owner_uri,
                                       ConnectionType.OBJECT_EXLPORER)

        result = connection_service.connect(request)
        if result.error_message is not None:
            raise RuntimeError(result.error_message)

        return connection_service.get_connection(
            owner_uri, ConnectionType.OBJECT_EXLPORER)
Exemple #24
0
    def test_disconnect_for_invalid_connection(self):
        """Test that _close_connections returns False and closes nothing for a connection type that is not open"""
        owner_uri = 'someuri'
        open_connection = MockPGServerConnection(cur=None,
                                                 host='myserver1',
                                                 name='postgres1',
                                                 user='******')

        # Register a ConnectionInfo holding only a DEFAULT connection
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data({'abc': 123}))
        info.add_connection(ConnectionType.DEFAULT, open_connection)
        self.connection_service.owner_to_connection_map[owner_uri] = info

        # Asking to close the (absent) EDIT connection must do nothing
        result = self.connection_service._close_connections(info, ConnectionType.EDIT)
        open_connection.close.assert_not_called()
        self.assertFalse(result)
    def _handle_create_session_request(self, request_context: RequestContext,
                                       params: ConnectionDetails) -> None:
        """
        Handle a create object explorer session request.

        Step 1 fills in a missing database name / port on ``params`` (mutating it
        in place), generates the session ID, registers the session in
        ``self._session_map`` and responds with a CreateSessionResponse.
        If a session with that ID already exists, it responds False instead. Any
        exception in this step is logged and sent back via ``send_error``.

        Step 2 (only reached if step 1 succeeded) starts a daemon thread that runs
        ``self._initialize_session(request_context, session)``. A failure to
        start it is reported through ``self._session_created_error``.

        :param request_context: context used to send the response or error
        :param params: connection details for the session; modified in place
        """
        # Step 1: Create the session
        try:
            # Make sure we have the appropriate session params
            utils.validate.is_not_none('params', params)

            # Use the provider's default db if db name was not specified
            # (for a provider other than MySQL/PG the name is left unset)
            if params.database_name is None or params.database_name == '':
                if self._provider == utils.constants.MYSQL_PROVIDER_NAME:
                    params.database_name = self._service_provider[
                        utils.constants.
                        WORKSPACE_SERVICE_NAME].configuration.my_sql.default_database
                elif self._provider == utils.constants.PG_PROVIDER_NAME:
                    params.database_name = self._service_provider[
                        utils.constants.
                        WORKSPACE_SERVICE_NAME].configuration.pgsql.default_database

            # Use the provider's default port if port number was not specified
            if not params.port:
                params.port = utils.constants.DEFAULT_PORT[self._provider]

            # Generate the session ID and create/store the session
            session_id = self._generate_session_uri(params, self._provider)
            session: ObjectExplorerSession = ObjectExplorerSession(
                session_id, params)

            # Add the session to session map in a lock to prevent race conditions between check and add
            with self._session_lock:
                if session_id in self._session_map:
                    # Removed the exception for now. But we need to investigate why we would get this
                    if self._service_provider.logger is not None:
                        self._service_provider.logger.error(
                            f'Object explorer session for {session_id} already exists!'
                        )
                    # Duplicate: respond False and skip step 2 (no init task is started)
                    request_context.send_response(False)
                    return

                self._session_map[session_id] = session

            # Respond that the session was created
            response = CreateSessionResponse(session_id)
            request_context.send_response(response)

        except Exception as e:
            # Any failure in step 1 becomes an error response; step 2 is skipped
            message = f'Failed to create OE session: {str(e)}'
            if self._service_provider.logger is not None:
                self._service_provider.logger.error(message)
            request_context.send_error(message)
            return

        # Step 2: Connect the session and lookup the root node asynchronously
        # on a daemon thread, so it does not block this handler
        try:
            session.init_task = threading.Thread(
                target=self._initialize_session,
                args=(request_context, session))
            session.init_task.daemon = True
            session.init_task.start()
        except Exception as e:
            # TODO: Localize
            self._session_created_error(
                request_context, session,
                f'Failed to start OE init task: {str(e)}')
Example #26
0
    def setUp(self):
        """Set up the tests with a disaster recovery service and connection service with mock connection info"""
        self.disaster_recovery_service = DisasterRecoveryService()
        self.connection_service = ConnectionService()
        self.task_service = TaskService()
        self.disaster_recovery_service._service_provider = utils.get_mock_service_provider(
            {
                constants.CONNECTION_SERVICE_NAME: self.connection_service,
                constants.TASK_SERVICE_NAME: self.task_service
            })

        # Create connection information for use in the tests
        self.connection_details = ConnectionDetails()
        self.host = 'test_host'
        self.dbname = 'test_db'
        self.username = '******'
        self.connection_details.options = {
            'host': self.host,
            'dbname': self.dbname,
            'user': self.username,
            'port': 5432
        }
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri,
                                              self.connection_details)

        # Create backup parameters for the tests
        self.request_context = utils.MockRequestContext()
        self.backup_path = 'mock/path/test.sql'
        self.backup_type = 'sql'
        self.data_only = False
        self.no_owner = True
        self.schema = 'test_schema'
        self.backup_params = BackupParams.from_dict({
            'ownerUri': self.test_uri,
            'backupInfo': {
                'type': self.backup_type,
                'path': self.backup_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        self.restore_path = 'mock/path/test.dump'
        self.restore_params = RestoreParams.from_dict({
            'ownerUri': self.test_uri,
            'options': {
                'path': self.restore_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        self.pg_dump_exe = 'pg_dump'
        self.pg_restore_exe = 'pg_restore'

        # Create the mock task for the tests
        self.mock_action = mock.Mock()
        self.mock_task = Task(None, None, None, None, None,
                              self.request_context, self.mock_action)
        self.mock_task.start = mock.Mock()