def server_info_is_cloud_internal(self, host_suffix, is_cloud):
        """Connect to 'myserver<host_suffix>' and check the reported cloud flag.

        Shared helper for the cloud-detection tests: ``is_cloud`` is the value
        ``response.server_info.is_cloud`` must equal for that host name.
        """
        host = f'myserver{host_suffix}'
        owner_uri = 'someuri'

        # Options the client sends for the connection
        details = ConnectionDetails()
        details.options = {
            'user': '******',
            'password': '******',
            'host': host,
            'dbname': 'postgres'}

        # Fake connection handed back by the patched psycopg2.connect
        fake_connection = MockConnection(dsn_parameters={
            'host': host,
            'dbname': 'postgres',
            'user': '******'
        })

        request = ConnectRequestParams(details, owner_uri, ConnectionType.DEFAULT)
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)):
            response = self.connection_service.connect(request)

        # The response needs an id, a server version and the expected cloud flag
        self.assertIsNotNone(response.connection_id)
        self.assertIsNotNone(response.server_info.server_version)
        self.assertEqual(response.server_info.is_cloud, is_cloud)
Esempio n. 2
0
    def test_generate_uri_missing_params(self):
        """_generate_session_uri must raise when any one of host, dbname or user is None"""
        complete = {'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER}

        # Build one ConnectionDetails per required key, with just that key blanked out
        incomplete_details = []
        for missing in ('host', 'dbname', 'user'):
            options = dict(complete)
            options[missing] = None
            incomplete_details.append(ConnectionDetails.from_data(options))

        for details in incomplete_details:
            with self.assertRaises(Exception):
                ObjectExplorerService._generate_session_uri(details)
Esempio n. 3
0
def _connection_details() -> Tuple[ConnectionDetails, str]:
    """Return test ConnectionDetails (host/dbname/user) and the OE session URI generated from them."""
    details = ConnectionDetails()
    details.options = {'host': TEST_HOST, 'dbname': TEST_DBNAME, 'user': TEST_USER}
    return details, ObjectExplorerService._generate_session_uri(details)
    def test_build_connection_response(self):
        """_build_connection_response should describe the connection stored for the requested type"""
        server_name, db_name, user = 'testserver', 'testdb', '******'
        connection = MockConnection({'host': server_name, 'dbname': db_name, 'user': user})
        conn_type = ConnectionType.EDIT
        owner_uri = 'test_uri'

        # Place the mock connection straight into the ConnectionInfo's map for the EDIT type
        info = ConnectionInfo(owner_uri, ConnectionDetails.from_data(opts={}))
        info._connection_map = {conn_type: connection}

        response = pgsqltoolsservice.connection.connection_service._build_connection_response(info, conn_type)

        # Owner URI, server info, summary and type must all reflect the inputs above
        self.assertEqual(response.owner_uri, owner_uri)
        self.assertEqual(response.server_info.server_version, connection.server_version)
        self.assertEqual(response.server_info.is_cloud, False)
        self.assertEqual(response.connection_summary.server_name, server_name)
        self.assertEqual(response.connection_summary.database_name, db_name)
        self.assertEqual(response.connection_summary.user_name, user)
        self.assertEqual(response.type, conn_type)
    def test_list_databases_handles_query_failure(self):
        """A failing list-databases query must be reported as an error, with no response or notification"""
        owner_uri = 'someuri'
        failing_cursor = MockCursor([('database1',), ('database2',)])
        failing_cursor.fetchall.side_effect = psycopg2.ProgrammingError('')
        fake_connection = MockConnection(
            dsn_parameters={'host': 'myserver', 'dbname': 'postgres', 'user': '******'},
            cursor=failing_cursor)
        request_context = utils.MockRequestContext()

        # Register a ConnectionInfo for the URI so the handler can look it up
        self.connection_service.owner_to_connection_map[owner_uri] = ConnectionInfo(
            owner_uri, ConnectionDetails.from_data({}))

        params = ListDatabasesParams()
        params.owner_uri = owner_uri
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)):
            self.connection_service.handle_list_databases(request_context, params)

        # Only an error should have been sent back
        self.assertIsNone(request_context.last_notification_method)
        self.assertIsNone(request_context.last_notification_params)
        self.assertIsNone(request_context.last_response_params)
        self.assertIsNotNone(request_context.last_error_message)
    def test_get_connection_creates_connection(self):
        """After connect opens an EDIT connection, get_connection returns it and psycopg2.connect ran once"""
        uri = 'someuri'
        conn_type = ConnectionType.EDIT
        fake_connection = MockConnection(
            dsn_parameters={'host': 'myserver', 'dbname': 'postgres', 'user': '******'})

        details = ConnectionDetails.from_data({})
        self.connection_service.owner_to_connection_map[uri] = ConnectionInfo(uri, details)

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)) as connect_mock:
            self.connection_service.connect(ConnectRequestParams(details, uri, conn_type))
            self.assertEqual(self.connection_service.get_connection(uri, conn_type), fake_connection)
            connect_mock.assert_called_once()
    def test_list_databases(self):
        """handle_list_databases should respond with the first column of each row the query returns"""
        rows = [('database1',), ('database2',)]
        uri = 'someuri'
        fake_connection = MockConnection(
            dsn_parameters={'host': 'myserver', 'dbname': 'postgres', 'user': '******'},
            cursor=MockCursor(rows))
        request_context = utils.MockRequestContext()

        # Register a ConnectionInfo for the URI so the handler can look it up
        self.connection_service.owner_to_connection_map[uri] = ConnectionInfo(uri, ConnectionDetails.from_data({}))

        params = ListDatabasesParams()
        params.owner_uri = uri
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=fake_connection)):
            self.connection_service.handle_list_databases(request_context, params)

        expected = [name for (name,) in rows]
        self.assertEqual(request_context.last_response_params.database_names, expected)
    def test_disconnect_all_types(self):
        """_close_connections with no type given should close every open connection and return True"""
        uri = 'someuri'
        default_connection = MockConnection(dsn_parameters={
            'host': 'myserver1', 'dbname': 'postgres1', 'user': '******'})
        edit_connection = MockConnection(dsn_parameters={
            'host': 'myserver2', 'dbname': 'postgres2', 'user': '******'})

        # One ConnectionInfo holding both a DEFAULT and an EDIT connection
        info = ConnectionInfo(uri, ConnectionDetails.from_data({'abc': 123}))
        for conn_type, connection in ((ConnectionType.DEFAULT, default_connection),
                                      (ConnectionType.EDIT, edit_connection)):
            info.add_connection(conn_type, connection)
        self.connection_service.owner_to_connection_map[uri] = info

        closed = self.connection_service._close_connections(info)
        for connection in (default_connection, edit_connection):
            connection.close.assert_called_once()
        self.assertTrue(closed)
    def test_same_options_uses_existing_connection(self):
        """Reconnecting an owner URI with identical options should reuse its connection, not reopen or close it"""
        uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockConnection(dsn_parameters={
            'host': 'myserver', 'dbname': 'postgres', 'user': '******'})

        # Register the existing connection under the owner URI
        existing_details = ConnectionDetails.from_data({
            'host': 'myserver', 'dbname': 'postgres', 'user': '******', 'abc': 123})
        existing_info = ConnectionInfo(uri, existing_details)
        existing_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[uri] = existing_info

        # Ask for a connection using exactly the stored options
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': uri,
            'type': conn_type,
            'connection': {'options': existing_details.options},
        })
        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=existing_connection)) as connect_mock:
            response = self.connection_service.connect(params)
            connect_mock.assert_not_called()

        existing_connection.close.assert_not_called()
        self.assertIsNotNone(response.connection_id)
    def test_disconnect_for_invalid_connection(self):
        """_close_connections should return False and close nothing when the requested type is not open"""
        uri = 'someuri'
        default_connection = MockConnection(dsn_parameters={
            'host': 'myserver1', 'dbname': 'postgres1', 'user': '******'})

        info = ConnectionInfo(uri, ConnectionDetails.from_data({'abc': 123}))
        info.add_connection(ConnectionType.DEFAULT, default_connection)
        self.connection_service.owner_to_connection_map[uri] = info

        # Only a DEFAULT connection exists, so closing the EDIT type finds nothing to close
        closed = self.connection_service._close_connections(info, ConnectionType.EDIT)
        default_connection.close.assert_not_called()
        self.assertFalse(closed)
Esempio n. 11
0
    def setUp(self):
        """Build a ServiceProvider exposing a ConnectionService, plus shared connection fixtures and refresher mocks"""
        self.default_connection_key = 'server_db_user'

        # Service provider that already holds the connection service and reports itself initialized
        self.mock_connection_service = ConnectionService()
        self.mock_server = JSONRPCServer(None, None)
        provider = ServiceProvider(self.mock_server, {}, None)
        provider._services[constants.CONNECTION_SERVICE_NAME] = self.mock_connection_service
        provider._is_initialized = True
        self.mock_service_provider = provider

        # Connection fixtures: details, the context key / URI derived from them, and a ConnectionInfo
        details = ConnectionDetails.from_data({})
        details.server_name = 'test_host'
        details.database_name = 'test_db'
        details.user_name = 'user'
        self.connection_details = details
        self.expected_context_key = 'test_host|test_db|user'
        self.expected_connection_uri = INTELLISENSE_URI + self.expected_context_key
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri, details)

        # Mocked CompletionRefresher pieces, to avoid creating a separate refresh thread
        self.refresher_mock = mock.MagicMock()
        self.refresh_method_mock = mock.MagicMock()
        self.refresher_mock.refresh = self.refresh_method_mock
Esempio n. 12
0
    def test_routing_target_get_nodes_not_empty(self):
        """RoutingTarget.get_nodes returns its folder nodes first, followed by the generator's nodes"""
        first, second = NodeInfo(), NodeInfo()
        generator = mock.MagicMock(return_value=[first, second])
        folders = [routing.Folder('Folder1', 'fp1'), routing.Folder('Folder2', 'fp2')]
        target = routing.RoutingTarget(folders, generator)

        path = '/'
        params = {}
        session = ObjectExplorerSession('session_id', ConnectionDetails())
        output = target.get_nodes(False, path, session, params)

        # Two folder nodes, then the two generated nodes (by identity)
        self.assertIsInstance(output, list)
        for node in output:
            self.assertIsInstance(node, NodeInfo)
        self.assertEqual(len(output), 4)
        self.assertEqual([node.node_type for node in output[:2]], ['Folder', 'Folder'])
        self.assertIs(output[2], first)
        self.assertIs(output[3], second)

        # The generator receives the same arguments get_nodes was called with
        generator.assert_called_once_with(False, path, session, params)
Esempio n. 13
0
 def test_routing_invalid_path(self):
     """route_request should raise ValueError for a path that no route matches"""
     with self.assertRaises(ValueError):
         session = ObjectExplorerSession('session_id', ConnectionDetails())
         routing.route_request(False, session, '!/invalid!')
Esempio n. 14
0
    def test_routing_target_get_nodes_empty(self):
        """A RoutingTarget with neither folders nor a node generator yields an empty list"""
        empty_target = routing.RoutingTarget(None, None)
        session = ObjectExplorerSession('session_id', ConnectionDetails())
        self.assertListEqual(empty_target.get_nodes(False, '/', session, {}), [])
Esempio n. 15
0
    def test_routing_match(self):
        """Routing the root path should produce a list made up only of NodeInfo objects"""
        session = ObjectExplorerSession('session_id', ConnectionDetails())
        nodes = routing.route_request(False, session, '/')

        self.assertIsInstance(nodes, list)
        for each in nodes:
            self.assertIsInstance(each, NodeInfo)
Esempio n. 16
0
    def test_on_connect_sends_notification(self):
        """
        on_connect should trigger a completion refresh and, once that refresh reports completion, send an
        IntelliSense-ready notification and store a ScriptParseInfo keyed to the connection.
        This goes through the operations queue layer rather than stubbing it out.
        """
        service: LanguageService = self._init_service_with_flow_validator()
        connection_info = ConnectionInfo(
            'file://msuri.sql',
            ConnectionDetails.from_data({
                'host': None,
                'dbname': 'TEST_DBNAME',
                'user': '******'
            }))

        # Stub the connection service so connecting always succeeds
        successful_result = mock.MagicMock()
        successful_result.error_message = None
        self.mock_connection_service.get_connection = mock.Mock(return_value=mock.MagicMock())
        self.mock_connection_service.connect = mock.MagicMock(return_value=successful_result)

        def check_notification(notification: IntelliSenseReadyParams):
            self.assertEqual(notification.owner_uri, connection_info.owner_uri)

        # Expect exactly one ready notification for this URI
        self.flow_validator.add_expected_notification(
            IntelliSenseReadyParams, INTELLISENSE_READY_NOTIFICATION, check_notification)

        fake_refresher = mock.MagicMock()
        refresh = mock.MagicMock()
        fake_refresher.refresh = refresh
        with mock.patch('pgsqltoolsservice.language.operations_queue.CompletionRefresher') as refresher_class:
            refresher_class.return_value = fake_refresher
            task: threading.Thread = service.on_connect(connection_info)
            # Simulate the refresh finishing by invoking the callback handed to refresh()
            refresh.assert_called_once()
            on_refresh_done = refresh.call_args[0][0]
            self.assertIsNotNone(on_refresh_done)
            on_refresh_done(None)
            task.join()

        # The ready notification was sent ...
        self.flow_validator.validate()
        # ... and a ScriptParseInfo now exists, keyed to the connection
        parse_info: ScriptParseInfo = service.get_script_parse_info(connection_info.owner_uri)
        self.assertIsNotNone(parse_info)
        self.assertEqual(parse_info.connection_key, OperationsQueue.create_key(connection_info))
    def run_on_connect_callback(self, conn_type: ConnectionType,
                                expect_callback: bool) -> None:
        """Connect with ``conn_type`` and verify whether registered on-connect callbacks fire.

        Registers two mock callbacks, then connects and fetches a connection of type
        ``conn_type`` for 'someuri' against a patched psycopg2.connect.

        :param conn_type: Connection type to open
        :param expect_callback: True if each callback must have been called exactly once
            with an argument whose ``owner_uri`` is the connection URI; False if neither
            callback may have been called
        """
        callbacks = [MagicMock(), MagicMock()]
        for callback in callbacks:
            self.connection_service.register_on_connect_callback(callback)

        # Parameters for the connection
        connection_uri = 'someuri'
        connection_details = ConnectionDetails()
        connection_details.options = {
            'user': '******',
            'password': '******',
            'host': 'myserver',
            'dbname': 'postgres'
        }

        # Mock connection returned by the patched psycopg2.connect
        mock_connection = MockConnection(dsn_parameters={
            'host': 'myserver',
            'dbname': 'postgres',
            'user': '******'
        })

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=mock_connection)):
            self.connection_service.connect(
                ConnectRequestParams(connection_details, connection_uri, conn_type))
            # Return value intentionally unused; only the call itself matters here
            self.connection_service.get_connection(connection_uri, conn_type)

        # The callbacks fire (or not) depending on the connection type
        for callback in callbacks:
            if expect_callback:
                callback.assert_called_once()
                # First positional argument carries the owner URI of the new connection
                callargs: ConnectionInfo = callback.call_args[0][0]
                self.assertEqual(callargs.owner_uri, connection_uri)
            else:
                callback.assert_not_called()
Esempio n. 18
0
    def test_handle_close_session_incomplete_params(self):
        """Closing an OE session with empty connection details should produce an error response"""
        # Empty details are enough to make session URI generation fail; each individual
        # missing parameter is covered by test_generate_uri_missing_params
        validator = RequestFlowValidator().add_expected_error(
            type(None), RequestFlowValidator.basic_error_validation)
        self.oe._handle_close_session_request(validator.request_context, ConnectionDetails.from_data({}))

        validator.validate()
    def test_get_connection_info(self):
        """get_connection_info should return the very ConnectionInfo registered for the owner URI"""
        uri = 'someuri'
        registered = ConnectionInfo(uri, ConnectionDetails.from_data({}))
        self.connection_service.owner_to_connection_map[uri] = registered

        self.assertIs(self.connection_service.get_connection_info(uri), registered)
Esempio n. 20
0
    def handle_change_database_request(self, request_context: RequestContext,
                                       params: ChangeDatabaseRequestParams) -> bool:
        """Reconnect an existing owner URI to a different database.

        Copies the options of the connection registered for ``params.owner_uri``,
        replaces ``dbname`` with ``params.new_database`` and passes a DEFAULT-type
        connect request for the same owner URI to ``handle_connect_request``.

        :param request_context: Context handed to ``handle_connect_request``
        :param params: Owner URI of the existing connection and the new database name
        :returns: False if no connection is registered for the owner URI; True once the
            connect request has been dispatched (this does not indicate the connect succeeded)
        """
        connection_info: ConnectionInfo = self.get_connection_info(params.owner_uri)

        if connection_info is None:
            # NOTE(review): nothing is sent through request_context on this path;
            # confirm the caller reports the failure to the client
            return False

        # Copy so the stored connection's options are left untouched
        connection_info_params: Dict[str, str] = connection_info.details.options.copy()
        connection_info_params["dbname"] = params.new_database
        connection_details: ConnectionDetails = ConnectionDetails.from_data(connection_info_params)

        connection_request_params: ConnectRequestParams = ConnectRequestParams(
            connection_details, params.owner_uri, ConnectionType.DEFAULT)
        self.handle_connect_request(request_context, connection_request_params)
        # Honour the declared bool return type instead of implicitly returning None
        return True
    def _create_connection(self, session: ObjectExplorerSession, database_name: str) -> Optional[psycopg2.extensions.connection]:
        """Open an OBJECT_EXLPORER-type connection for ``session`` against ``database_name``.

        The session's connection options are copied with ``dbname`` overridden, and the
        connection is requested under the owner URI ``session.id + database_name``.

        :raises RuntimeError: if the connect result carries an error message
        :returns: whatever the connection service's ``get_connection`` returns for that
            owner URI and connection type
        """
        connection_service = self._service_provider[utils.constants.CONNECTION_SERVICE_NAME]

        db_options = {**session.connection_details.options, 'dbname': database_name}
        owner_uri = session.id + database_name

        result = connection_service.connect(
            ConnectRequestParams(ConnectionDetails.from_data(db_options), owner_uri, ConnectionType.OBJECT_EXLPORER))
        if result.error_message is not None:
            raise RuntimeError(result.error_message)

        return connection_service.get_connection(owner_uri, ConnectionType.OBJECT_EXLPORER)
    def test_changing_options_disconnects_existing_connection(self):
        """
        Reconnecting an owner URI with different options should close the connection it already holds
        """
        uri = 'someuri'
        conn_type = ConnectionType.DEFAULT
        existing_connection = MockConnection(dsn_parameters={
            'host': 'myserver', 'dbname': 'postgres', 'user': '******'})

        # Register the existing connection under the owner URI with the original options
        existing_info = ConnectionInfo(uri, ConnectionDetails.from_data({
            'host': 'myserver', 'dbname': 'postgres', 'user': '******', 'abc': 123}))
        existing_info.add_connection(conn_type, existing_connection)
        self.connection_service.owner_to_connection_map[uri] = existing_info

        # Same owner URI, but the 'abc' option differs
        params: ConnectRequestParams = ConnectRequestParams.from_dict({
            'ownerUri': uri,
            'type': conn_type,
            'connection': {
                'options': {'host': 'myserver', 'dbname': 'postgres', 'user': '******', 'abc': 234}
            },
        })

        with mock.patch('psycopg2.connect', new=mock.Mock(return_value=existing_connection)):
            self.connection_service.connect(params)
        existing_connection.close.assert_called_once()
    def _handle_create_session_request(self, request_context: RequestContext, params: ConnectionDetails) -> None:
        """Handle a create object explorer session request.

        Step 1 (synchronous): validate ``params``, substitute the workspace's configured
        default database when none was given, generate the session ID and register a new
        ObjectExplorerSession under ``self._session_lock``, then respond with a
        CreateSessionResponse. If a session with that ID already exists, the response is
        ``False`` instead. Exceptions raised in this step are logged (when a logger is set)
        and sent back as an error, and step 2 is skipped.

        Step 2: start a daemon thread running ``_initialize_session`` for the new session;
        failure to start it is reported through ``_session_created_error``.

        :param request_context: Context used to send the response or error
        :param params: Connection details for the session; ``database_name`` may be
            overwritten in place with the configured default database
        """
        # Step 1: Create the session
        try:
            # Make sure we have the appropriate session params
            utils.validate.is_not_none('params', params)

            # Missing or empty database name: fall back to the configured default
            if not params.database_name:
                params.database_name = self._service_provider[utils.constants.WORKSPACE_SERVICE_NAME].configuration.pgsql.default_database

            # Generate the session ID and create the session
            session_id = self._generate_session_uri(params)
            session: ObjectExplorerSession = ObjectExplorerSession(session_id, params)

            # Check-and-add under the lock so concurrent requests cannot both register the same ID
            with self._session_lock:
                if session_id in self._session_map:
                    # Removed the exception for now. But we need to investigate why we would get this
                    if self._service_provider.logger is not None:
                        self._service_provider.logger.error(f'Object explorer session for {session_id} already exists!')
                    request_context.send_response(False)
                    return

                self._session_map[session_id] = session

            # Respond that the session was created
            response = CreateSessionResponse(session_id)
            request_context.send_response(response)

        except Exception as e:
            message = f'Failed to create OE session: {str(e)}'
            if self._service_provider.logger is not None:
                self._service_provider.logger.error(message)
            request_context.send_error(message)
            return

        # Step 2: Connect the session and look up the root node asynchronously
        try:
            session.init_task = threading.Thread(target=self._initialize_session, args=(request_context, session))
            session.init_task.daemon = True
            session.init_task.start()
        except Exception as e:
            # TODO: Localize
            self._session_created_error(request_context, session, f'Failed to start OE init task: {str(e)}')
Esempio n. 24
0
    def test_handle_create_session_incomplete_params(self):
        """Creating an OE session from empty connection details should send an error and store no session"""
        service = ObjectExplorerService()
        service._service_provider = utils.get_mock_service_provider({})

        # Empty details are enough to make session URI generation fail; each individual
        # missing parameter is covered by test_generate_uri_missing_params
        validator = RequestFlowValidator().add_expected_error(
            type(None), RequestFlowValidator.basic_error_validation)
        service._handle_create_session_request(validator.request_context, ConnectionDetails.from_data({}))

        validator.validate()
        self.assertDictEqual(service._session_map, {})
Esempio n. 25
0
    def setUp(self):
        """Set up the tests with a disaster recovery service and connection service with mock connection info.

        The DisasterRecoveryService gets a mock service provider that exposes a real
        ConnectionService and TaskService. The fixture also builds connection details and a
        ConnectionInfo for 'test_uri', backup/restore request params, and a Task whose
        ``start`` is replaced by a mock.
        """
        self.disaster_recovery_service = DisasterRecoveryService()
        self.connection_service = ConnectionService()
        self.task_service = TaskService()
        # Expose only the connection and task services to the disaster recovery service
        self.disaster_recovery_service._service_provider = utils.get_mock_service_provider(
            {
                constants.CONNECTION_SERVICE_NAME: self.connection_service,
                constants.TASK_SERVICE_NAME: self.task_service
            })

        # Create connection information for use in the tests
        self.connection_details = ConnectionDetails()
        self.host = 'test_host'
        self.dbname = 'test_db'
        self.username = '******'
        self.connection_details.options = {
            'host': self.host,
            'dbname': self.dbname,
            'user': self.username,
            'port': 5432
        }
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri,
                                              self.connection_details)

        # Create backup parameters for the tests
        self.request_context = utils.MockRequestContext()
        self.backup_path = 'mock/path/test.sql'
        self.backup_type = 'sql'
        self.data_only = False
        self.no_owner = True
        self.schema = 'test_schema'
        self.backup_params = BackupParams.from_dict({
            'ownerUri': self.test_uri,
            'backupInfo': {
                'type': self.backup_type,
                'path': self.backup_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        # Restore params reuse the owner URI and the data_only/no_owner/schema values from above
        self.restore_path = 'mock/path/test.dump'
        self.restore_params = RestoreParams.from_dict({
            'ownerUri': self.test_uri,
            'options': {
                'path': self.restore_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        # Executable names; presumably what the tests expect the service to invoke — confirm
        self.pg_dump_exe = 'pg_dump'
        self.pg_restore_exe = 'pg_restore'

        # Create the mock task for the tests
        self.mock_action = mock.Mock()
        self.mock_task = Task(None, None, None, None, None,
                              self.request_context, self.mock_action)
        # Replace start with a mock so tests can inspect it; presumably this also keeps the
        # action from actually running — confirm against Task.start
        self.mock_task.start = mock.Mock()