예제 #1
0
 def handle_restore_request(self, request_context: RequestContext,
                            params: RestoreParams) -> None:
     """
     Handle a restore request by queueing a task that performs the restore.

     The connection registered for ``params.owner_uri`` is looked up first.
     When there is none, an error is sent on ``request_context`` and no task
     is created. Otherwise a 'Restore' task wrapping ``_perform_restore`` is
     passed to the task service and an empty response is sent back.
     """
     connection_service = self._service_provider[
         constants.CONNECTION_SERVICE_NAME]
     connection_info: ConnectionInfo = connection_service.get_connection_info(
         params.owner_uri)
     if connection_info is None:
         # TODO: Localize
         request_context.send_error(
             'No connection corresponding to the given owner URI')
         return

     provider: str = self._service_provider.provider
     connection_options = connection_info.details.options
     host = connection_options['host']
     database = connection_options['dbname']

     # The task runs this callable; the connection and the request
     # parameters are bound up front
     restore_action = functools.partial(_perform_restore, connection_info,
                                        params)
     # TODO: Localize the task name and description
     task = Task('Restore', f'Host: {host}, Database: {database}', provider,
                 host, database, request_context, restore_action)

     task_service = self._service_provider[constants.TASK_SERVICE_NAME]
     task_service.start_task(task)
     request_context.send_response({})
예제 #2
0
def _perform_backup_restore(connection_info: ConnectionInfo,
                            process_args: List[str], options: Dict[str, Any],
                            task: Task):
    """Call out to pg_dump or pg_restore using the arguments given and additional arguments built from the given options dict

    :param connection_info: Connection whose ``details.options`` supplies the optional 'password' entry
    :param process_args: Command line to run. It is modified in place: each generated flag is inserted
        just before the current last element (or appended when the list is empty)
    :param options: Option names mapped to values. ``None`` and ``False`` values are skipped, ``True``
        produces a bare ``--name`` flag and any other value produces ``--name=value``
    :param task: Task being run. Its cancellation lock is held for the whole process run and its
        ``on_cancel`` callback is pointed at the process's ``terminate`` method
    :returns: A TaskResult that is CANCELED when the task was already canceled, FAILED (with the
        error text or the process's stderr output) when the process could not be run or exited
        with a non-zero return code, and SUCCEEDED otherwise
    """
    for option, value in options.items():
        # Don't add the option to the arguments if it is not set
        if value is None or value is False:
            continue
        # Replace underscores with dashes in the option name
        key_name = inflection.dasherize(option)
        if value is True:
            # The option is a boolean flag, so just add the option
            process_args.insert(-1, f'--{key_name}')
        else:
            # The option has a value, so add the flag with its value
            process_args.insert(-1, f'--{key_name}={value}')
    with task.cancellation_lock:
        if task.canceled:
            return TaskResult(TaskStatus.CANCELED)
        # Bound before the try block so the error handler can tell whether the process was ever created
        dump_restore_process = None
        try:
            # Hand the password to the child process through its environment
            os.putenv('PGPASSWORD',
                      connection_info.details.options.get('password') or '')
            dump_restore_process = subprocess.Popen(process_args,
                                                    stderr=subprocess.PIPE)
            task.on_cancel = dump_restore_process.terminate
            _, stderr = dump_restore_process.communicate()
        except subprocess.SubprocessError as err:
            # Only kill a process that was actually started; otherwise the original error would be
            # replaced by an UnboundLocalError raised from this handler
            if dump_restore_process is not None:
                dump_restore_process.kill()
            return TaskResult(TaskStatus.FAILED, str(err))
    if dump_restore_process.returncode != 0:
        # Decode leniently so undecodable bytes (or missing output) cannot hide the failure itself
        error_output = (stderr or b'').decode('utf-8', errors='replace')
        return TaskResult(TaskStatus.FAILED, error_output)
    return TaskResult(TaskStatus.SUCCEEDED)
예제 #3
0
 def setUp(self):
     """Create a task service behind a mock service provider, plus two tasks whose start method is mocked"""
     self.task_service = TaskService()
     registered_services = {constants.TASK_SERVICE_NAME: self.task_service}
     self.service_provider = ServiceProviderMock(registered_services)
     self.request_validator = RequestFlowValidator()

     def build_mock_task():
         # Create the task, register the 'tasks/newtaskcreated' notification expected for it,
         # then replace start with a mock so the task's real start method is never invoked
         new_task = Task(None, None, None, None, None,
                         self.request_validator.request_context,
                         mock.Mock(), mock.Mock())
         self.request_validator.add_expected_notification(
             TaskInfo, 'tasks/newtaskcreated')
         new_task.start = mock.Mock()
         return new_task

     self.mock_task_1 = build_mock_task()
     self.mock_task_2 = build_mock_task()
예제 #4
0
 def create_task(self) -> Task:
     """Build a Task from the name, description, provider, server, database, request context and action stored on the test"""
     task_arguments = (
         self.task_name,
         self.task_description,
         self.task_provider,
         self.server_name,
         self.database_name,
         self.request_context,
         self.mock_action,
     )
     return Task(*task_arguments)
예제 #5
0
    def setUp(self):
        """Build the disaster recovery service, its mock service provider, the connection info and the backup/restore parameters used by the tests"""
        self.disaster_recovery_service = DisasterRecoveryService()
        self.connection_service = ConnectionService()
        self.task_service = TaskService()
        registered_services = {
            constants.CONNECTION_SERVICE_NAME: self.connection_service,
            constants.TASK_SERVICE_NAME: self.task_service,
        }
        self.disaster_recovery_service._service_provider = utils.get_mock_service_provider(
            registered_services)

        # Connection information shared by the tests
        self.connection_details = ConnectionDetails()
        self.host, self.dbname, self.username = 'test_host', 'test_db', '******'
        self.connection_details.options = dict(host=self.host,
                                               dbname=self.dbname,
                                               user=self.username,
                                               port=5432)
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri,
                                              self.connection_details)

        # Values the backup and restore parameters have in common
        self.request_context = utils.MockRequestContext()
        self.backup_path = 'mock/path/test.sql'
        self.backup_type = 'sql'
        self.data_only = False
        self.no_owner = True
        self.schema = 'test_schema'
        shared_options = {
            'data_only': self.data_only,
            'no_owner': self.no_owner,
            'schema': self.schema,
        }

        # Backup parameters
        backup_info = {
            'type': self.backup_type,
            'path': self.backup_path,
            **shared_options,
        }
        self.backup_params = BackupParams.from_dict({
            'ownerUri': self.test_uri,
            'backupInfo': backup_info,
        })

        # Restore parameters
        self.restore_path = 'mock/path/test.dump'
        restore_options = {'path': self.restore_path, **shared_options}
        self.restore_params = RestoreParams.from_dict({
            'ownerUri': self.test_uri,
            'options': restore_options,
        })
        self.pg_dump_exe = 'pg_dump'
        self.pg_restore_exe = 'pg_restore'

        # Task whose start method is mocked so that starting it never invokes the real implementation
        self.mock_action = mock.Mock()
        self.mock_task = Task(None, None, None, None, None,
                              self.request_context, self.mock_action)
        self.mock_task.start = mock.Mock()
예제 #6
0
class TestDisasterRecoveryService(unittest.TestCase):
    """Methods for testing the disaster recovery service"""

    # Directory containing the service entry point when running from source / from a cx_freeze build
    _SOURCE_DIR = '/ossdbtoolsservice/ossdbtoolsservice'
    _FROZEN_DIR = '/ossdbtoolsservice/build/ossdbtoolsservice'
    # Entry point file name used for sys.argv[0] in each of those two modes
    _SOURCE_MAIN = 'ossdbtoolsservice_main.py'
    _FROZEN_MAIN = 'ossdbtoolsservice_main'

    # Names patched by several tests
    _GET_PG_EXE_PATH_TARGET = 'ossdbtoolsservice.disaster_recovery.disaster_recovery_service._get_pg_exe_path'
    _GET_CONNECTION_TARGET = 'ossdbtoolsservice.connection.ConnectionInfo.get_connection'
    _TASK_TARGET = 'ossdbtoolsservice.disaster_recovery.disaster_recovery_service.Task'

    # Fake pg_exes layout for each sys.platform value used by the tests:
    #   folder         - platform folder name under pg_exes
    #   versions       - version folders reported by the fake os.walk
    #   dirs / files   - sub-directories / files reported for every version folder
    #   server_version - server version passed to _get_pg_exe_path
    #   exe            - expected pg_dump location relative to the platform folder
    _PG_EXE_LAYOUTS = {
        'linux': {
            'folder': 'linux',
            'versions': ('10', '11'),
            'dirs': ('bin', 'lib'),
            'files': (),
            'server_version': (10, 0),
            'exe': '10/bin/pg_dump',
        },
        'darwin': {
            'folder': 'mac',
            'versions': ('9.5', '9.6'),
            'dirs': ('bin', 'lib'),
            'files': (),
            'server_version': (9, 6),
            'exe': '9.6/bin/pg_dump',
        },
        'win32': {
            'folder': 'win',
            'versions': ('11', '12'),
            'dirs': (),
            'files': ('pg_dump', 'pg_restore'),
            'server_version': (11, 0),
            'exe': '11/pg_dump.exe',
        },
    }

    def setUp(self):
        """Set up the tests with a disaster recovery service and connection service with mock connection info"""
        self.disaster_recovery_service = DisasterRecoveryService()
        self.connection_service = ConnectionService()
        self.task_service = TaskService()
        self.disaster_recovery_service._service_provider = utils.get_mock_service_provider(
            {
                constants.CONNECTION_SERVICE_NAME: self.connection_service,
                constants.TASK_SERVICE_NAME: self.task_service
            })

        # Create connection information for use in the tests
        self.connection_details = ConnectionDetails()
        self.host = 'test_host'
        self.dbname = 'test_db'
        self.username = '******'
        self.connection_details.options = {
            'host': self.host,
            'dbname': self.dbname,
            'user': self.username,
            'port': 5432
        }
        self.test_uri = 'test_uri'
        self.connection_info = ConnectionInfo(self.test_uri,
                                              self.connection_details)

        # Create backup parameters for the tests
        self.request_context = utils.MockRequestContext()
        self.backup_path = 'mock/path/test.sql'
        self.backup_type = 'sql'
        self.data_only = False
        self.no_owner = True
        self.schema = 'test_schema'
        self.backup_params = BackupParams.from_dict({
            'ownerUri': self.test_uri,
            'backupInfo': {
                'type': self.backup_type,
                'path': self.backup_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })

        # Create restore parameters for the tests
        self.restore_path = 'mock/path/test.dump'
        self.restore_params = RestoreParams.from_dict({
            'ownerUri': self.test_uri,
            'options': {
                'path': self.restore_path,
                'data_only': self.data_only,
                'no_owner': self.no_owner,
                'schema': self.schema
            }
        })
        self.pg_dump_exe = 'pg_dump'
        self.pg_restore_exe = 'pg_restore'

        # Create the mock task for the tests
        self.mock_action = mock.Mock()
        self.mock_task = Task(None, None, None, None, None,
                              self.request_context, self.mock_action)
        self.mock_task.start = mock.Mock()

    @classmethod
    def _fake_walk(cls, install_dir: str, platform_name: str) -> list:
        """Build the list a mocked os.walk returns: the platform folder followed by one entry per version folder"""
        layout = cls._PG_EXE_LAYOUTS[platform_name]
        platform_dir = f"{install_dir}/pg_exes/{layout['folder']}"
        entries = [(platform_dir, layout['versions'], ())]
        entries.extend((f'{platform_dir}/{version}', layout['dirs'],
                        layout['files']) for version in layout['versions'])
        return entries

    def _check_get_pg_exe_path(self, install_dir: str, main_file: str,
                               platform_name: str):
        """
        Call _get_pg_exe_path for pg_dump with sys.argv[0], sys.platform, os.walk and os.path.exists faked, and
        assert that it returns the pg_dump location recorded for that platform in _PG_EXE_LAYOUTS

        :param install_dir: Directory that sys.argv[0] and the fake pg_exes tree are placed in
        :param main_file: File name of the entry point that sys.argv[0] points at
        :param platform_name: Value assigned to sys.platform; also selects the fake layout
        """
        layout = self._PG_EXE_LAYOUTS[platform_name]
        # Back up these values so that the test can overwrite them
        old_arg0 = sys.argv[0]
        old_platform = sys.platform
        try:
            with mock.patch('os.walk', new=mock.Mock(return_value=self._fake_walk(install_dir, platform_name))), \
                    mock.patch('os.path.exists', new=mock.Mock(return_value=True)):
                # Override sys.argv[0] to simulate where the service is running from
                sys.argv[0] = os.path.normpath(f'{install_dir}/{main_file}')

                # If I get the executable path on the given platform
                sys.platform = platform_name
                path = disaster_recovery_service._get_pg_exe_path(
                    self.pg_dump_exe, layout['server_version'])
                # Then the path points into that platform's directory for the requested server version
                self.assertEqual(
                    path,
                    os.path.normpath(
                        f"{install_dir}/pg_exes/{layout['folder']}/{layout['exe']}"
                    ))
        finally:
            sys.argv[0] = old_arg0
            sys.platform = old_platform

    def test_get_pg_exe_path_local_linux(self):
        """Test the get_pg_exe_path function for linux when the service is running from source code"""
        # The expected path uses the linux directory and does not have a trailing .exe
        self._check_get_pg_exe_path(self._SOURCE_DIR, self._SOURCE_MAIN,
                                    'linux')

    def test_get_pg_exe_path_local_mac(self):
        """Test the get_pg_exe_path function for mac when the service is running from source code"""
        # The expected path uses the mac directory and does not have a trailing .exe
        self._check_get_pg_exe_path(self._SOURCE_DIR, self._SOURCE_MAIN,
                                    'darwin')

    def test_get_pg_exe_path_local_win(self):
        """Test the get_pg_exe_path function for windows when the service is running from source code"""
        # The expected path uses the win directory and does have a trailing .exe
        self._check_get_pg_exe_path(self._SOURCE_DIR, self._SOURCE_MAIN,
                                    'win32')

    def test_get_pg_exe_path_frozen_linux(self):
        """Test the get_pg_exe_path function for linux when the service is running from a cx_freeze build"""
        # The expected path uses the linux directory and does not have a trailing .exe
        self._check_get_pg_exe_path(self._FROZEN_DIR, self._FROZEN_MAIN,
                                    'linux')

    def test_get_pg_exe_path_frozen_mac(self):
        """Test the get_pg_exe_path function for mac when the service is running from a cx_freeze build"""
        # The expected path uses the mac directory and does not have a trailing .exe
        self._check_get_pg_exe_path(self._FROZEN_DIR, self._FROZEN_MAIN,
                                    'darwin')

    def test_get_pg_exe_path_frozen_win(self):
        """Test the get_pg_exe_path function for windows when the service is running from a cx_freeze build"""
        # The expected path uses the win directory and does have a trailing .exe
        self._check_get_pg_exe_path(self._FROZEN_DIR, self._FROZEN_MAIN,
                                    'win32')

    def test_get_pg_exe_path_does_not_exist(self):
        """Test that the get_pg_exe_path function throws an error if the executable being searched for does not exist"""
        return_value = self._fake_walk(self._SOURCE_DIR, 'win32')
        with mock.patch('os.walk', new=mock.Mock(return_value=return_value)):
            with mock.patch('os.path.exists', new=mock.Mock(return_value=False)), \
                    self.assertRaises(ValueError):
                disaster_recovery_service._get_pg_exe_path(
                    'not_pg_dump', (11, 0))

    def test_get_pg_server_folder_does_not_exist(self):
        """Test that the get_pg_exe_path function throws an error if the server version folder being searched for does not exist"""
        return_value = self._fake_walk(self._SOURCE_DIR, 'win32')
        with mock.patch('os.walk', new=mock.Mock(return_value=return_value)), \
                self.assertRaises(ValueError):
            disaster_recovery_service._get_pg_exe_path('pg_dump', (9, 6))

    def test_perform_backup(self):
        """Test that the perform_backup method passes the correct parameters to pg_dump"""
        exe_name = self.pg_dump_exe
        test_method = disaster_recovery_service._perform_backup
        test_params = self.backup_params
        expected_args = [
            f'--file={self.backup_path}', '--format=p',
            f'--dbname={self.dbname}', f'--host={self.host}',
            f'--username={self.username}', '--no-owner',
            f'--schema={self.schema}', '--port=5432'
        ]
        self._test_perform_backup_restore_internal(exe_name, test_method,
                                                   test_params, expected_args)

    def test_perform_restore(self):
        """Test that the perform_restore method passes the correct parameters to pg_restore"""
        exe_name = self.pg_restore_exe
        test_method = disaster_recovery_service._perform_restore
        test_params = self.restore_params
        expected_args = [
            self.restore_path, f'--dbname={self.dbname}',
            f'--host={self.host}', f'--username={self.username}', '--no-owner',
            f'--schema={self.schema}', '--port=5432'
        ]
        self._test_perform_backup_restore_internal(exe_name, test_method,
                                                   test_params, expected_args)

    def _test_perform_backup_restore_internal(self, exe_name: str,
                                              test_method: Callable,
                                              test_params,
                                              expected_args: List[str]):
        """Run test_method with the executable lookup and subprocess.Popen mocked and check the command line it builds"""
        mock_pg_path = f'mock/{exe_name}'
        mock_process = mock.Mock()
        mock_process.returncode = 0
        mock_process.communicate = mock.Mock(return_value=(b'', b''))
        mock_connection = pg_utils.MockPGServerConnection(None)
        with mock.patch(self._GET_PG_EXE_PATH_TARGET, new=mock.Mock(return_value=mock_pg_path)) as mock_get_path, \
                mock.patch('subprocess.Popen', new=mock.Mock(return_value=mock_process)) as mock_popen:
            # If I perform a backup/restore
            with mock.patch(self._GET_CONNECTION_TARGET,
                            new=mock.Mock(return_value=mock_connection)):
                task_result = test_method(self.connection_info, test_params,
                                          self.mock_task)
            # Then the code got the path of the executable
            mock_get_path.assert_called_once_with(
                exe_name, mock_connection.server_version)
            # And ran the executable as a subprocess
            mock_popen.assert_called_once()
            # And then called communicate on the process
            mock_process.communicate.assert_called_once_with()
            # And the arguments for the subprocess.Popen call were the expected values
            actual_args = mock_popen.call_args[0][0]
            self.assertEqual(actual_args[0], mock_pg_path)
            pg_exe_flags = actual_args[1:]
            for expected_arg in expected_args:
                self.assertIn(expected_arg, pg_exe_flags)
            self.assertEqual(len(expected_args), len(pg_exe_flags))
            # And the task returns a successful result
            self.assertIs(task_result.status, TaskStatus.SUCCEEDED)

    def test_perform_backup_fails(self):
        """Test that the perform_backup method handles failures by recording pg_dump's stderr output and marking the task failed"""
        self._test_perform_backup_restore_fails_internal(
            self.pg_dump_exe, disaster_recovery_service._perform_backup,
            self.backup_params)

    def test_perform_restore_fails(self):
        """Test that the perform_restore method handles failures by recording pg_restore's stderr output and marking the task failed"""
        self._test_perform_backup_restore_fails_internal(
            self.pg_restore_exe, disaster_recovery_service._perform_restore,
            self.restore_params)

    def _test_perform_backup_restore_fails_internal(self, exe_name: str,
                                                    test_method: Callable,
                                                    test_params):
        """Run test_method against a mocked process that exits with return code 1 and check the failed result"""
        mock_pg_path = f'mock/{exe_name}'
        mock_process = mock.Mock()
        mock_process.returncode = 1
        test_error_message = b'test error message'
        mock_process.communicate = mock.Mock(return_value=(b'',
                                                           test_error_message))
        with mock.patch(self._GET_PG_EXE_PATH_TARGET, new=mock.Mock(return_value=mock_pg_path)), \
                mock.patch('subprocess.Popen', new=mock.Mock(return_value=mock_process)):
            with mock.patch(
                    self._GET_CONNECTION_TARGET,
                    new=mock.Mock(
                        return_value=pg_utils.MockPGServerConnection(None))):
                # If I perform a backup/restore that fails
                task_result = test_method(self.connection_info, test_params,
                                          self.mock_task)
                # Then the task returns a failed result
                self.assertIs(task_result.status, TaskStatus.FAILED)
                # And the task contains the error message from stderr
                self.assertEqual(task_result.error_message,
                                 str(test_error_message, 'utf-8'))

    def test_perform_backup_no_exe(self):
        """Test that the perform_backup task fails when the pg_dump exe is not found"""
        self._test_perform_backup_restore_no_exe_internal(
            disaster_recovery_service._perform_backup, self.backup_params)

    def test_perform_restore_no_exe(self):
        """Test that the perform_restore task fails when the pg_restore exe is not found"""
        self._test_perform_backup_restore_no_exe_internal(
            disaster_recovery_service._perform_restore, self.restore_params)

    def _test_perform_backup_restore_no_exe_internal(self,
                                                     test_method: Callable,
                                                     test_params):
        """Run test_method while os.path.exists reports False and check that it fails without spawning a process"""
        mock_connection = pg_utils.MockPGServerConnection(None)
        with mock.patch('os.path.exists', new=mock.Mock(return_value=False)), \
                mock.patch('subprocess.Popen') as mock_popen:
            with mock.patch(self._GET_CONNECTION_TARGET,
                            new=mock.Mock(return_value=mock_connection)):
                # If I perform a backup/restore when the executable cannot be found
                task_result = test_method(self.connection_info, test_params,
                                          mock.Mock())
                # Then the task fails and does not try to kick off a new process
                self.assertIs(task_result.status, TaskStatus.FAILED)
                mock_popen.assert_not_called()

    def test_handle_backup_request(self):
        """Test that the backup request handler responds properly and kicks off a task to perform the backup"""
        self._test_handle_backup_restore_internal(
            self.disaster_recovery_service.handle_backup_request,
            disaster_recovery_service._perform_backup, self.backup_params)

    def test_handle_restore_request(self):
        """Test that the restore request handler responds properly and kicks off a task to perform the restore"""
        self._test_handle_backup_restore_internal(
            self.disaster_recovery_service.handle_restore_request,
            disaster_recovery_service._perform_restore, self.restore_params)

    def _test_handle_backup_restore_internal(self, test_handler: Callable,
                                             test_method: Callable,
                                             test_params):
        """Call test_handler with a registered connection and check the task it creates, starts and responds with"""
        # Set up the connection service to return the test's connection information
        self.connection_service.owner_to_connection_map[
            self.test_uri] = self.connection_info
        # Set up a mock task so that the backup/restore code does not actually run in a separate thread
        with mock.patch(self._TASK_TARGET, new=mock.Mock(return_value=self.mock_task)) as mock_task_constructor, \
                mock.patch('functools.partial', new=mock.Mock(return_value=self.mock_action)) as mock_partial:
            # When I call the backup/restore request handler
            test_handler(self.request_context, test_params)
            # Then a mock task is created and started
            mock_task_constructor.assert_called_once()
            self.mock_task.start.assert_called_once()
            # And the mock task was initialized with the expected parameters
            parameters = mock_task_constructor.call_args[0]
            self.assertEqual(parameters[2], constants.PG_PROVIDER_NAME)
            self.assertEqual(parameters[3], self.host)
            self.assertEqual(parameters[4], self.dbname)
            self.assertIs(parameters[6], self.mock_action)
            mock_partial.assert_called_once_with(test_method,
                                                 self.connection_info,
                                                 test_params)
            # And the handler sends an empty response to indicate success
            self.assertEqual(self.request_context.last_response_params, {})

    def test_handle_backup_request_no_connection(self):
        """Test that the backup request handler responds with an error if there is no connection for the given owner URI"""
        self._test_handle_backup_restore_request_no_connection(
            self.disaster_recovery_service.handle_backup_request,
            self.backup_params)

    def test_handle_restore_request_no_connection(self):
        """Test that the restore request handler responds with an error if there is no connection for the given owner URI"""
        self._test_handle_backup_restore_request_no_connection(
            self.disaster_recovery_service.handle_restore_request,
            self.restore_params)

    def _test_handle_backup_restore_request_no_connection(
            self, test_handler: Callable, test_params):
        """Call test_handler without registering a connection and check that it reports an error and creates no task"""
        # Set up a mock task so that the backup/restore code does not actually run in a separate thread
        with mock.patch(self._TASK_TARGET, new=mock.Mock(return_value=self.mock_task)) as mock_task_constructor, \
                mock.patch('functools.partial', new=mock.Mock(return_value=self.mock_action)):
            # If I call the request handler and there is no connection corresponding to the given owner URI
            test_handler(self.request_context, test_params)
            # Then a mock task is not created
            mock_task_constructor.assert_not_called()
            # And the handler sends an error response to indicate failure
            self.assertIsNotNone(self.request_context.last_error_message)

    def test_canceled_task_does_not_spawn_process(self):
        """Test that the pg_dump/pg_restore process is not created if the task has been canceled"""
        # Set up the task to be canceled
        self.mock_task.canceled = True
        mock_connection = pg_utils.MockPGServerConnection(None)
        with mock.patch('subprocess.Popen', new=mock.Mock()) as mock_popen:
            with mock.patch(self._GET_CONNECTION_TARGET,
                            new=mock.Mock(return_value=mock_connection)):
                # If I try to perform a backup/restore for a canceled task
                disaster_recovery_service._perform_backup_restore(
                    self.connection_info, [], {}, self.mock_task)

                # Then the process was not created
                mock_popen.assert_not_called()

    def test_cancel_backup(self):
        """Test that backups can be canceled by calling terminate on the pg_dump process"""

        # Set up the task to be canceled when communicate would normally be called

        def cancel_function(*_, **__):
            self.mock_task.cancel()
            # Returning DEFAULT makes the mock fall back to its configured return_value
            return mock.DEFAULT

        self.mock_task.status = TaskStatus.IN_PROGRESS
        mock_process = mock.Mock()
        mock_process.communicate = mock.Mock(return_value=(None, None),
                                             side_effect=cancel_function)
        mock_process.terminate = mock.Mock()
        mock_process.returncode = 0
        with mock.patch('subprocess.Popen',
                        new=mock.Mock(return_value=mock_process)):
            # If I perform a backup/restore that kicks off the subprocess and then I cancel the task
            disaster_recovery_service._perform_backup_restore(
                self.connection_info, [], {}, self.mock_task)

            # Then the backup/restore process was terminated
            mock_process.terminate.assert_called_once()
예제 #7
0
 def start_task(self, task: Task) -> None:
     """Record the task in the task map under its id, then start it"""
     task_id = task.id
     # Store the task before starting it so it is already in the map while it runs
     self._task_map[task_id] = task
     task.start()