예제 #1
0
def get_mock_service_provider(
        service_map: Optional[dict] = None,
        provider_name: Optional[str] = PG_PROVIDER_NAME) -> ServiceProvider:
    """
    Generates an initialized ServiceProvider with the given services

    :param service_map: A dictionary mapping service names to services. When
        None, the provider keeps the empty service map it was constructed with.
    :param provider_name: Provider name passed to the ServiceProvider
        (defaults to PG_PROVIDER_NAME)
    :return: A ServiceProvider with no JSON RPC server, a mock logger, and
        ``_is_initialized`` already set to True
    """
    provider = ServiceProvider(None, {}, provider_name, get_mock_logger())
    if service_map is not None:
        # The map is installed as-is (not copied), so later changes to the
        # caller's dict are visible through the provider
        provider._services = service_map
    # Mark the provider initialized directly instead of calling initialize()
    provider._is_initialized = True
    return provider
    def test_register(self):
        """Registering the language service hooks it into server and provider"""
        # Setup: a JSON RPC server with mocked handler setters, wrapped in an
        # initialized provider that only knows the connection service
        rpc_server: JSONRPCServer = JSONRPCServer(None, None)
        rpc_server.set_notification_handler = mock.MagicMock()
        rpc_server.set_request_handler = mock.MagicMock()
        service_provider: ServiceProvider = ServiceProvider(
            rpc_server,
            {constants.CONNECTION_SERVICE_NAME: ConnectionService},
            PG_PROVIDER_NAME,
            utils.get_mock_logger(),
        )
        service_provider._is_initialized = True
        connection_service: ConnectionService = service_provider[
            constants.CONNECTION_SERVICE_NAME]
        self.assertEqual(0, len(connection_service._on_connect_callbacks))

        # If: a language service registers itself with the provider
        language_service: LanguageService = LanguageService()
        language_service.register(service_provider)

        # Then: notification and request handlers were set, one on-connect
        # callback was added, and one shutdown handler exists on the server
        rpc_server.set_notification_handler.assert_called()
        rpc_server.set_request_handler.assert_called()
        self.assertEqual(1, len(connection_service._on_connect_callbacks))
        self.assertEqual(1, rpc_server.count_shutdown_handlers())

        # ... and the service keeps a reference to the provider
        self.assertIs(language_service._service_provider, service_provider)  # noqa
 def setUp(self):
     """Build a mocked server/provider and default text-document fixtures"""
     self.default_uri = 'file://my.sql'
     self.flow_validator = RequestFlowValidator()

     # JSON RPC server whose request-handler setter is replaced by a mock
     self.mock_server_set_request = mock.MagicMock()
     self.mock_server = JSONRPCServer(None, None)
     self.mock_server.set_request_handler = self.mock_server_set_request

     # Initialized provider holding the workspace and connection services
     self.mock_workspace_service = WorkspaceService()
     self.mock_connection_service = ConnectionService()
     self.mock_service_provider = ServiceProvider(
         self.mock_server, {}, PG_PROVIDER_NAME, None)
     services = self.mock_service_provider._services
     services[constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service
     services[constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
     self.mock_service_provider._is_initialized = True

     # Default position (line 3, character 10) and identifier for the URI
     position_data = {
         'text_document': {'uri': self.default_uri},
         'position': {'line': 3, 'character': 10},
     }
     self.default_text_position = TextDocumentPosition.from_dict(position_data)
     self.default_text_document_id = TextDocumentIdentifier.from_dict(
         {'uri': self.default_uri})
    def setUp(self):
        """Prepare a provider, connection info and a mocked refresher"""
        self.default_connection_key = 'server_db_user'

        # Initialized provider exposing only the connection service
        self.mock_connection_service = ConnectionService()
        self.mock_server = JSONRPCServer(None, None)
        self.mock_service_provider = ServiceProvider(
            self.mock_server, {}, PG_PROVIDER_NAME, None)
        services = self.mock_service_provider._services
        services[constants.CONNECTION_SERVICE_NAME] = (
            self.mock_connection_service)
        self.mock_service_provider._is_initialized = True

        # Connection fixtures and the context key / URI tests expect from them
        details = ConnectionDetails.from_data({})
        self.connection_details = details
        details.server_name = 'test_host'
        details.database_name = 'test_db'
        details.user_name = 'user'
        self.expected_context_key = 'test_host|test_db|user'
        self.expected_connection_uri = (INTELLISENSE_URI +
                                        self.expected_context_key)
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri, details)

        # Mock CompletionRefresher so no separate thread gets created
        self.refresher_mock = mock.MagicMock()
        self.refresh_method_mock = mock.MagicMock()
        self.refresher_mock.refresh = self.refresh_method_mock
예제 #5
0
    def setUp(self):
        """Create the mocks shared by the query execution service tests.

        Runs before each test: builds a mock cursor/connection pair and a
        QueryExecutionService whose connection service always returns that
        mock connection.
        """
        # Mock cursor over two fixed rows, linked to a mock MySQL connection
        self.rows = [(1, 'Text 1'), (2, 'Text 2')]
        self.cursor = utils.MockCursor(self.rows)
        self.mock_pymysql_connection = utils.MockPyMySQLConnection(
            parameters={'host': 'test', 'dbname': 'test'})
        self.connection = MockMySQLServerConnection()
        self.connection.cursor.return_value = self.cursor
        self.cursor.connection = self.connection
        self.connection_service = ConnectionService()
        self.request_context = utils.MockRequestContext()

        # Query execution service wired to an initialized MySQL provider
        self.query_execution_service = QueryExecutionService()
        provider = ServiceProvider(None, {}, constants.MYSQL_PROVIDER_NAME)
        provider._services = {
            constants.CONNECTION_SERVICE_NAME: self.connection_service
        }
        provider._is_initialized = True
        self.service_provider = provider
        self.query_execution_service._service_provider = provider

        # Every get_connection call resolves to the same mock connection
        def _get_connection(owner_uri: str, connection_type: ConnectionType):
            return self.connection

        self.connection_service.get_connection = mock.Mock(
            side_effect=_get_connection)
    def test_format_mysql_doc_range(self):
        """
        Test that the format document range codepath works as expected with a mysql doc
        """
        # Setup: replace the service provider with one using the MySQL
        # provider name, holding the existing workspace and connection mocks
        self.mock_service_provider = ServiceProvider(self.mock_server, {},
                                                     MYSQL_PROVIDER_NAME, None)
        self.mock_service_provider._services[
            constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service
        self.mock_service_provider._services[
            constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
        self.mock_service_provider._is_initialized = True

        # If: The workspace holds a two-line script whose 2nd line is a query
        # with a nested subselect
        input_lines: List[str] = [
            'select * from t1',
            'select * from foo where id in (select id from bar);'
        ]
        input_text = '\n'.join(input_lines)
        # Expected result for the 2nd line only: upper-case keywords (per
        # keyword_case='upper' below), tab-indented continuation lines
        expected_output = '\n'.join([
            'SELECT *', 'FROM foo', 'WHERE id IN', '\t\t\t\t(SELECT id',
            '\t\t\t\t\tFROM bar);'
        ])

        context: RequestContext = utils.MockRequestContext()
        config = Configuration()
        config.my_sql = MySQLConfiguration()
        config.my_sql.format.keyword_case = 'upper'
        self.mock_workspace_service._configuration = config
        # Only the workspace is needed; the script file object is unused
        workspace, _ = self._get_test_workspace(True, input_text)
        self.mock_workspace_service._workspace = workspace
        service: LanguageService = self._init_service()

        format_options = FormattingOptions()
        format_options.insert_spaces = False
        format_params = DocumentRangeFormattingParams()
        format_params.options = format_options
        format_params.text_document = self.default_text_document_id
        # add uri to valid uri set ensure request passes uri check
        # normally done in flavor change handler, but we are not testing that here
        service._valid_uri.add(format_params.text_document.uri)

        # When: I request format the 2nd line of a document
        format_params.range = Range.from_data(1, 0, 1, len(input_lines[1]))
        service.handle_doc_range_format_request(context, format_params)

        # Then:
        # ... only the 2nd line should be formatted
        context.send_response.assert_called_once()
        edits: List[TextEdit] = context.last_response_params
        self.assertGreater(len(edits), 0)
        self.assert_range_equals(edits[0].range, format_params.range)
        self.assertEqual(edits[0].new_text, expected_output)
    def test_format_mysql_doc(self):
        """
        Test that the format document codepath works as expected with a mysql doc
        """
        # Setup: replace the service provider with one using the MySQL
        # provider name, holding the existing workspace and connection mocks
        self.mock_service_provider = ServiceProvider(self.mock_server, {},
                                                     MYSQL_PROVIDER_NAME, None)
        self.mock_service_provider._services[
            constants.WORKSPACE_SERVICE_NAME] = self.mock_workspace_service
        self.mock_service_provider._services[
            constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
        self.mock_service_provider._is_initialized = True

        # If: We have a basic string to be formatted
        input_text = 'select * from foo where id in (select id from bar);'
        # Note: sqlparse always uses '\n' for line separator even on windows.
        # For now, respecting this behavior and leaving as-is
        expected_output = '\n'.join([
            'SELECT *', 'FROM foo', 'WHERE id IN', '\t\t\t\t(SELECT id',
            '\t\t\t\t\tFROM bar);'
        ])

        context: RequestContext = utils.MockRequestContext()
        config = Configuration()
        config.my_sql = MySQLConfiguration()
        config.my_sql.format.keyword_case = 'upper'
        self.mock_workspace_service._configuration = config
        # Only the workspace is needed; the script file object is unused
        workspace, _ = self._get_test_workspace(True, input_text)
        self.mock_workspace_service._workspace = workspace
        service: LanguageService = self._init_service()

        format_options = FormattingOptions()
        format_options.insert_spaces = False
        format_params = DocumentFormattingParams()
        format_params.options = format_options
        format_params.text_document = self.default_text_document_id
        # add uri to valid uri set ensure request passes uri check
        # normally done in flavor change handler, but we are not testing that here
        service._valid_uri.add(format_params.text_document.uri)

        # When: I request document formatting
        service.handle_doc_format_request(context, format_params)

        # Then:
        # ... The entire document text should be formatted
        context.send_response.assert_called_once()
        edits: List[TextEdit] = context.last_response_params
        self.assertGreater(len(edits), 0)
        self.assert_range_equals(edits[0].range,
                                 Range.from_data(0, 0, 0, len(input_text)))
        self.assertEqual(edits[0].new_text, expected_output)
예제 #8
0
def _create_server(input_stream, output_stream, server_logger, provider):
    """
    Build a JSON RPC server and register the full set of services on it.

    The server is constructed but not started. A ServiceProvider is created
    over it with every service class below and initialized before returning.

    :param input_stream: Input stream handed to the JSONRPCServer
    :param output_stream: Output stream handed to the JSONRPCServer
    :param server_logger: Logger shared by the server and the service provider
    :param provider: Provider name passed to the ServiceProvider
    :return: The (not yet started) JSONRPCServer
    """
    server = JSONRPCServer(input_stream, output_stream, server_logger)

    # Service name -> service class, in the same order as before
    service_classes = {
        constants.ADMIN_SERVICE_NAME: AdminService,
        constants.CAPABILITIES_SERVICE_NAME: CapabilitiesService,
        constants.CONNECTION_SERVICE_NAME: ConnectionService,
        constants.DISASTER_RECOVERY_SERVICE_NAME: DisasterRecoveryService,
        constants.LANGUAGE_SERVICE_NAME: LanguageService,
        constants.METADATA_SERVICE_NAME: MetadataService,
        constants.OBJECT_EXPLORER_NAME: ObjectExplorerService,
        constants.QUERY_EXECUTION_SERVICE_NAME: QueryExecutionService,
        constants.SCRIPTING_SERVICE_NAME: ScriptingService,
        constants.WORKSPACE_SERVICE_NAME: WorkspaceService,
        constants.EDIT_DATA_SERVICE_NAME: EditDataService,
        constants.TASK_SERVICE_NAME: TaskService,
    }

    service_provider = ServiceProvider(
        server, service_classes, provider, server_logger)
    service_provider.initialize()
    return server
예제 #9
0
    def test_registration(self):
        """Registering the scripting service sets request handlers only"""
        # Setup: a server with mocked handler setters, wrapped in a provider
        rpc_server: JSONRPCServer = JSONRPCServer(None, None)
        rpc_server.set_notification_handler = mock.MagicMock()
        rpc_server.set_request_handler = mock.MagicMock()
        provider: ServiceProvider = ServiceProvider(
            rpc_server, {}, PG_PROVIDER_NAME, utils.get_mock_logger())

        # If: a scripting service registers with the provider
        scripting_service: ScriptingService = ScriptingService()
        scripting_service.register(provider)

        # Then: request handlers were set, notification handlers were not
        rpc_server.set_request_handler.assert_called()
        rpc_server.set_notification_handler.assert_not_called()

        # ... and the service keeps a reference to the provider
        self.assertIs(scripting_service._service_provider, provider)
예제 #10
0
    def test_register(self):
        """Registering the workspace service sets notification handlers only"""
        # Setup: a server with mocked handler setters, wrapped in a provider
        rpc_server: JSONRPCServer = JSONRPCServer(None, None)
        rpc_server.set_notification_handler = MagicMock()
        rpc_server.set_request_handler = MagicMock()
        provider: ServiceProvider = ServiceProvider(
            rpc_server, {}, PG_PROVIDER_NAME, utils.get_mock_logger())

        # If: a workspace service registers with the provider
        workspace_service: WorkspaceService = WorkspaceService()
        workspace_service.register(provider)

        # Then: notification handlers were set, request handlers were not
        rpc_server.set_notification_handler.assert_called()
        rpc_server.set_request_handler.assert_not_called()

        # ... and the service keeps a reference to the provider
        self.assertIs(workspace_service._service_provider, provider)
예제 #11
0
    def test_initialization(self):
        """Registering the capabilities service installs request handlers"""
        # Setup: a server whose request-handler setter is mocked, inside a
        # provider, plus a fresh capabilities service
        set_request_mock = mock.MagicMock()
        server = JSONRPCServer(None, None)
        server.set_request_handler = set_request_mock
        service_provider = ServiceProvider(server, {}, PG_PROVIDER_NAME, None)
        capabilities_service = CapabilitiesService()

        # If: the service registers with the provider
        capabilities_service.register(service_provider)

        # Then: request handlers were set ...
        set_request_mock.assert_called()

        # ... each with an IncomingMessageConfiguration and a callable handler
        for recorded_call in set_request_mock.mock_calls:
            positional = recorded_call[1]
            self.assertIsInstance(positional[0], IncomingMessageConfiguration)
            self.assertTrue(callable(positional[1]))
    def test_register(self):
        """Registering the object explorer service sets request handlers only"""
        # Setup: a server with mocked handler setters, wrapped in a provider
        rpc_server: JSONRPCServer = JSONRPCServer(None, None)
        rpc_server.set_notification_handler = mock.MagicMock()
        rpc_server.set_request_handler = mock.MagicMock()
        provider: ServiceProvider = ServiceProvider(
            rpc_server, {}, constants.PG_PROVIDER_NAME,
            utils.get_mock_logger())

        # If: an object explorer service registers with the provider
        explorer_service = ObjectExplorerService()
        explorer_service.register(provider)

        # Then: request handlers were set, notification handlers were not
        rpc_server.set_request_handler.assert_called()
        rpc_server.set_notification_handler.assert_not_called()

        # ... and the service keeps a reference to the provider
        self.assertIs(explorer_service._service_provider, provider)
예제 #13
0
 def setUp(self):
     """Give a new object explorer service a bare provider and routing table"""
     provider = ServiceProvider(None, {}, PG_PROVIDER_NAME)
     explorer = ObjectExplorerService()
     # NOTE(review): other tests read the provider back from
     # `_service_provider`; confirm `service_provider` is the attribute
     # ObjectExplorerService actually reads
     explorer.service_provider = provider
     explorer._routing_table = PG_ROUTING_TABLE
     self.object_explorer_service = explorer