class BackupJob(object):
    def __init__(self, config, backup_name, seed_target, stagger,
                 enable_md5_checks, mode, temp_dir, parallel_snapshots,
                 parallel_uploads):
        self.id = uuid.uuid4()
        # TODO expose the argument below (Note that min(1000, <number_of_hosts>) will be used)
        self.orchestration_snapshots = Orchestration(config,
                                                     parallel_snapshots)
        self.orchestration_uploads = Orchestration(config, parallel_uploads)
        self.config = config
        self.backup_name = backup_name
        self.stagger = stagger
        self.seed_target = seed_target
        self.enable_md5_checks = enable_md5_checks
        self.mode = mode
        self.temp_dir = temp_dir
        self.work_dir = self.temp_dir / 'medusa-job-{id}'.format(id=self.id)
        self.hosts = {}
        self.cassandra = Cassandra(config)
        self.snapshot_tag = '{}{}'.format(self.cassandra.SNAPSHOT_PREFIX,
                                          self.backup_name)
        fqdn_resolver = medusa.config.evaluate_boolean(
            self.config.cassandra.resolve_ip_addresses)
        self.fqdn_resolver = HostnameResolver(fqdn_resolver)

    def execute(self):
        # Two steps: take a snapshot everywhere, then upload the backups to the external storage

        # Getting the list of Cassandra nodes.
        seed_target = self.seed_target if self.seed_target is not None else self.config.storage.fqdn
        session_provider = CqlSessionProvider([seed_target],
                                              self.config.cassandra)
        with session_provider.new_session() as session:
            tokenmap = session.tokenmap()
            self.hosts = [host for host in tokenmap.keys()]

        # First let's take a snapshot on all nodes at once
        # Here we will use parallelism of min(number of nodes, parallel_snapshots)
        logging.info('Creating snapshots on all nodes')
        self._create_snapshots()

        # Second, upload the snapshots to the external storage
        logging.info('Uploading snapshots from nodes to external storage')
        self._upload_backup()

    def _create_snapshots(self):
        # Run the snapshot command in parallel on all nodes
        create_snapshot_command = ' '.join(
            self.cassandra.create_snapshot_command(self.backup_name))
        pssh_run_success = self.orchestration_snapshots.\
            pssh_run(self.hosts,
                     create_snapshot_command,
                     hosts_variables={})
        if not pssh_run_success:
            # we could implement a retry.
            err_msg = 'Some nodes failed to create the snapshot.'
            logging.error(err_msg)
            raise Exception(err_msg)

        logging.info('A snapshot {} was created on all nodes.'.format(
            self.snapshot_tag))

    def _upload_backup(self):
        backup_command = self._build_backup_cmd()
        # Run upload in parallel or sequentially according to parallel_uploads defined by the user
        pssh_run_success = self.orchestration_uploads.pssh_run(
            self.hosts, backup_command, hosts_variables={})
        if not pssh_run_success:
            # we could implement a retry.
            err_msg = 'Some nodes failed to upload the backup.'
            logging.error(err_msg)
            raise Exception(err_msg)

        logging.info('A new backup {} was created on all nodes.'.format(
            self.backup_name))

    def _build_backup_cmd(self):
        stagger_option = '--in-stagger {}'.format(
            self.stagger) if self.stagger else ''
        enable_md5_checks_option = '--enable-md5-checks' if self.enable_md5_checks else ''

        # The same backup command is run on every node, so no per-host command substitution is needed here
        command = 'mkdir -p {work}; cd {work} && medusa-wrapper {sudo} medusa {config} -vvv backup-node ' \
                  '--backup-name {backup_name} {stagger} {enable_md5_checks} --mode {mode}' \
            .format(work=self.work_dir,
                    sudo='sudo' if medusa.utils.evaluate_boolean(self.config.cassandra.use_sudo) else '',
                    config=f'--config-file {self.config.file_path}' if self.config.file_path else '',
                    backup_name=self.backup_name,
                    stagger=stagger_option,
                    enable_md5_checks=enable_md5_checks_option,
                    mode=self.mode)
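        # Illustrative rendering of the command above (all values below are
        # hypothetical placeholders, not project defaults):
        #   mkdir -p /tmp/medusa-job-<uuid>; cd /tmp/medusa-job-<uuid> && \
        #     medusa-wrapper sudo medusa --config-file /etc/medusa/medusa.ini -vvv \
        #     backup-node --backup-name nightly --mode differential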

        logging.debug(
            'Running backup on all nodes with the following command {}'.format(
                command))

        return command
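
A minimal sketch of how the backup job above might be driven; the config object, the pathlib temp dir and every literal value below are placeholders assumed for illustration, not taken from the project:

import pathlib

# Hypothetical driver for BackupJob: `config` is assumed to be a MedusaConfig
# built elsewhere (e.g. parsed from medusa.ini); all other values are placeholders.
def run_nightly_backup(config):
    job = BackupJob(config=config,
                    backup_name='nightly',
                    seed_target=None,            # falls back to config.storage.fqdn
                    stagger=None,
                    enable_md5_checks=False,
                    mode='differential',
                    temp_dir=pathlib.Path('/tmp'),
                    parallel_snapshots=1,
                    parallel_uploads=1)
    job.execute()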
class RestoreJob(object):
    def __init__(self, cluster_backup, config, temp_dir, host_list, seed_target, keep_auth, verify,
                 parallel_restores, keyspaces={}, tables={}, bypass_checks=False, use_sstableloader=False):
        self.id = uuid.uuid4()
        self.ringmap = None
        self.cluster_backup = cluster_backup
        self.session_provider = None
        self.orchestration = Orchestration(config, parallel_restores)
        self.config = config
        self.host_list = host_list
        self.seed_target = seed_target
        self.keep_auth = keep_auth
        self.verify = verify
        self.in_place = None
        self.temp_dir = temp_dir  # temporary files
        self.work_dir = self.temp_dir / 'medusa-job-{id}'.format(id=self.id)
        self.host_map = {}  # Map of backup host/target host for the restore process
        self.keyspaces = keyspaces
        self.tables = tables
        self.bypass_checks = bypass_checks
        self.use_sstableloader = use_sstableloader
        self.pssh_pool_size = parallel_restores
        self.cassandra = Cassandra(config)
        fqdn_resolver = medusa.utils.evaluate_boolean(self.config.cassandra.resolve_ip_addresses)
        self.fqdn_resolver = HostnameResolver(fqdn_resolver)

    def execute(self):
        logging.info('Ensuring the backup is found and is complete')
        if not self.cluster_backup.is_complete():
            raise Exception('Backup is not complete')

        # CASE 1 : We're restoring using a seed target. Source/target mapping will be built based on tokenmap.
        if self.seed_target is not None:
            self.session_provider = CqlSessionProvider([self.seed_target],
                                                       self.config.cassandra)

            with self.session_provider.new_session() as session:
                self._populate_ringmap(self.cluster_backup.tokenmap, session.tokenmap())

        # CASE 2 : We're restoring a backup on a different cluster
        if self.host_list is not None:
            logging.info('Restore will happen on new hardware')
            self.in_place = False
            self._populate_hostmap()
            logging.info('Starting Restore on all the nodes in this list: {}'.format(self.host_list))

        self._restore_data()

    def _validate_ringmap(self, tokenmap, target_tokenmap):
        for host, ring_item in target_tokenmap.items():
            if not ring_item.get('is_up'):
                raise Exception('Target {host} is not up!'.format(host=host))
        if len(target_tokenmap) != len(tokenmap):
            return False
        return True

    def _populate_ringmap(self, tokenmap, target_tokenmap):

        def _tokens_from_ringitem(ringitem):
            return ','.join(map(str, ringitem['tokens']))

        def _token_counts_per_host(tokenmap):
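            # All hosts are expected to carry the same number of tokens, so the
            # count from the first host in the map is used for the whole cluster.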
            for host, ringitem in tokenmap.items():
                return len(ringitem['tokens'])

        def _hosts_from_tokenmap(tokenmap):
            hosts = set()
            for host, ringitem in tokenmap.items():
                hosts.add(host)
            return hosts

        def _chunk(my_list, nb_chunks):
            groups = []
            for i in range(nb_chunks):
                groups.append([])
            for i in range(len(my_list)):
                groups[i % nb_chunks].append(my_list[i])
            return groups
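        # For illustration, _chunk(['a', 'b', 'c', 'd', 'e'], 2) returns
        # [['a', 'c', 'e'], ['b', 'd']]: items are dealt round-robin across the chunks.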

        target_tokens = {}
        backup_tokens = {}
        topology_matches = self._validate_ringmap(tokenmap, target_tokenmap)
        self.in_place = self._is_restore_in_place(tokenmap, target_tokenmap)
        if self.in_place:
            logging.info("Restoring on the same cluster that was the backup was taken on (in place fashion)")
            self.keep_auth = False
        else:
            logging.info("Restoring on a different cluster than the backup one (remote fashion)")
            if self.keep_auth:
                logging.info('system_auth keyspace will be left untouched on the target nodes')
            else:
                # operators might not be aware of how system_auth is handled during restore, so ask what to do
                really_keep_auth = None
                while (really_keep_auth != 'Y' and really_keep_auth != 'n') and not self.bypass_checks:
                    really_keep_auth = input('Do you want to skip restoring the system_auth keyspace and keep the'
                                             + ' credentials of the target cluster? (Y/n)')
                self.keep_auth = really_keep_auth == 'Y'

        if topology_matches:
            target_tokens = {_tokens_from_ringitem(ringitem): host for host, ringitem in target_tokenmap.items()}
            backup_tokens = {_tokens_from_ringitem(ringitem): host for host, ringitem in tokenmap.items()}

            target_tokens_per_host = _token_counts_per_host(target_tokenmap)
            backup_tokens_per_host = _token_counts_per_host(tokenmap)

            # we must have the same number of tokens per host in both vnode and normal clusters
            if target_tokens_per_host != backup_tokens_per_host:
                logging.info('Source/target rings have different number of tokens per node: {}/{}'.format(
                    backup_tokens_per_host,
                    target_tokens_per_host
                ))
                topology_matches = False

            # if not using vnodes, the tokens must match exactly
            if backup_tokens_per_host == 1 and target_tokens.keys() != backup_tokens.keys():
                extras = target_tokens.keys() ^ backup_tokens.keys()
                logging.info('Tokenmap is differently distributed. Extra items: {}'.format(extras))
                topology_matches = False

        if topology_matches:
            # We can associate each restore node with exactly one backup node
            backup_ringmap = collections.defaultdict(list)
            target_ringmap = collections.defaultdict(list)
            for token, host in backup_tokens.items():
                backup_ringmap[token].append(host)
            for token, host in target_tokens.items():
                target_ringmap[token].append(host)

            self.ringmap = backup_ringmap
            i = 0
            for token, hosts in backup_ringmap.items():
                # take the node that has the same token list or pick the one with the same position in the map.
                restore_host = target_ringmap.get(token, list(target_ringmap.values())[i])[0]
                is_seed = self.fqdn_resolver.resolve_fqdn(restore_host) in self._get_seeds_fqdn()
                self.host_map[restore_host] = {'source': [hosts[0]], 'seed': is_seed}
                i += 1
        else:
            # Topologies are different between backup and restore clusters. Using the sstableloader for restore.
            self.use_sstableloader = True
            backup_hosts = _hosts_from_tokenmap(tokenmap)
            restore_hosts = list(_hosts_from_tokenmap(target_tokenmap))
            if len(backup_hosts) >= len(restore_hosts):
                grouped_backups = _chunk(list(backup_hosts), len(restore_hosts))
            else:
                grouped_backups = _chunk(list(backup_hosts), len(backup_hosts))
            for i in range(min([len(grouped_backups), len(restore_hosts)])):
                # associate one restore host with several backups as we don't have the same number of nodes.
                self.host_map[restore_hosts[i]] = {'source': grouped_backups[i], 'seed': False}

    def _is_restore_in_place(self, backup_tokenmap, target_tokenmap):
        # If at least one node is part of both tokenmaps, then we're restoring in place
        # Otherwise we're restoring a remote cluster
        return len(set(backup_tokenmap.keys()) & set(target_tokenmap.keys())) > 0
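        # For example (hypothetical hostnames): if the backup tokenmap lists node1
        # and node2 while the target tokenmap lists node2 and node3, the overlap
        # {node2} is non-empty and the restore is treated as in-place.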

    def _get_seeds_fqdn(self):
        seeds = list()
        for seed in self.cassandra.seeds:
            seeds.append(self.fqdn_resolver.resolve_fqdn(seed))
        return seeds

    def _populate_hostmap(self):
        with open(self.host_list, 'r') as f:
            for line in f.readlines():
                seed, target, source = line.replace('\n', '').split(self.config.storage.host_file_separator)
                # in Python, bool('False') evaluates to True, so test membership explicitly as below
                self.host_map[self.fqdn_resolver.resolve_fqdn(target.strip())] \
                    = {'source': [self.fqdn_resolver.resolve_fqdn(source.strip())], 'seed': seed in ['True']}
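        # Illustrative line from the host list file, assuming ',' is the configured
        # storage.host_file_separator (hostnames are hypothetical):
        #   True,target1.example.com,backup1.example.com
        # maps backup1.example.com onto target1.example.com and flags the target as a seed.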

    def _restore_data(self):
        # create workdir on each target host
        # Later: distribute a credential
        # construct command for each target host
        # invoke `nohup medusa-wrapper #{command}` on each target host
        # wait for exit on each
        logging.info('Starting cluster restore...')
        logging.info('Working directory for this execution: {}'.format(self.work_dir))
        for target, sources in self.host_map.items():
            logging.info('About to restore on {} using {} as backup source'.format(target, sources))

        logging.info('This will delete all data on the target nodes and replace it with backup {}.'
                     .format(self.cluster_backup.name))

        proceed = None
        while (proceed != 'Y' and proceed != 'n') and not self.bypass_checks:
            proceed = input('Are you sure you want to proceed? (Y/n)')

        if proceed == 'n':
            err_msg = 'Restore manually cancelled'
            logging.error(err_msg)
            raise Exception(err_msg)

        # work out which nodes are seeds in the target cluster
        target_seeds = [t for t, s in self.host_map.items() if s['seed']]
        logging.info("target seeds : {}".format(target_seeds))
        # work out the full list of target hosts
        target_hosts = list(self.host_map.keys())
        logging.info("target hosts : {}".format(target_hosts))

        if self.use_sstableloader is False:
            # stop all target nodes
            logging.info('Stopping Cassandra on all nodes currently up')

            # Generate a Job ID for this run
            job_id = str(uuid.uuid4())
            logging.debug('Job id is: {}'.format(job_id))
            # Define command to run
            command = self.config.cassandra.stop_cmd
            logging.debug('Command to run is: {}'.format(command))

            self.orchestration.pssh_run(target_hosts, command, hosts_variables={})

        else:
            # we're using the sstableloader, which requires (re)creating the schema and emptying the tables
            logging.info("Restoring schema on the target cluster")
            self._restore_schema()

        # trigger restores everywhere at once
        # pass in seed info so that non-seeds can wait for seeds before starting
        # seeds, naturally, don't wait for anything

        # Build the per-host argument tuples used for command substitution
        hosts_variables = []
        for target, source in [(t, s['source']) for t, s in self.host_map.items()]:
            logging.info('Restoring data on {}...'.format(target))
            seeds = '' if target in target_seeds or len(target_seeds) == 0 \
                    else '--seeds {}'.format(','.join(target_seeds))
            hosts_variables.append((','.join(source), seeds))

        command = self._build_restore_cmd()
        pssh_run_success = self.orchestration.pssh_run(target_hosts,
                                                       command,
                                                       hosts_variables=hosts_variables)

        if not pssh_run_success:
            # we could implement a retry.
            err_msg = 'Some nodes failed to restore. Exiting'
            logging.error(err_msg)
            raise Exception(err_msg)

        logging.info('Restore process is complete. The cluster should be up shortly.')

        if self.verify:
            verify_restore(target_hosts, self.config)

    def _build_restore_cmd(self):
        in_place_option = '--in-place' if self.in_place else '--remote'
        keep_auth_option = '--keep-auth' if self.keep_auth else ''
        keyspace_options = expand_repeatable_option('keyspace', self.keyspaces)
        table_options = expand_repeatable_option('table', self.tables)
        # We explicitly set --no-verify since we are doing verification here in this module
        # from the control node
        verify_option = '--no-verify'

        # %s placeholders in the below command will get replaced by pssh using per host command substitution
        command = 'mkdir -p {work}; cd {work} && medusa-wrapper sudo medusa --fqdn=%s -vvv restore-node ' \
                  '{in_place} {keep_auth} %s {verify} --backup-name {backup} --temp-dir {temp_dir} ' \
                  '{use_sstableloader} {keyspaces} {tables}' \
            .format(work=self.work_dir,
                    in_place=in_place_option,
                    keep_auth=keep_auth_option,
                    verify=verify_option,
                    backup=self.cluster_backup.name,
                    temp_dir=self.temp_dir,
                    use_sstableloader='--use-sstableloader' if self.use_sstableloader is True else '',
                    keyspaces=keyspace_options,
                    tables=table_options)
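        # With per-host substitution, pssh fills the first %s with the source host(s)
        # and the second with the optional --seeds argument, e.g. (hypothetical values):
        #   mkdir -p /tmp/medusa-job-<uuid>; cd /tmp/medusa-job-<uuid> && \
        #     medusa-wrapper sudo medusa --fqdn=backup1.example.com -vvv restore-node \
        #     --in-place  --seeds seed1.example.com --no-verify --backup-name nightly ...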

        logging.debug('Preparing to restore on all nodes with the following command {}'.format(command))

        return command

    def _restore_schema(self):
        schema = parse_schema(self.cluster_backup.schema)
        with self.session_provider.new_session() as session:
            for keyspace in schema.keys():
                if keyspace.startswith("system"):
                    continue
                else:
                    self._create_or_recreate_schema_objects(session, keyspace, schema[keyspace])

    def _create_or_recreate_schema_objects(self, session, keyspace, keyspace_schema):
        logging.info("(Re)creating schema for keyspace {}".format(keyspace))
        if keyspace not in session.cluster.metadata.keyspaces:
            # Keyspace doesn't exist on the target cluster, so create it along with all of its tables.
            session.execute(keyspace_schema['create_statement'])
        for mv in keyspace_schema['materialized_views']:
            # MVs need to be dropped before we drop the tables
            logging.debug("Dropping MV {}.{}".format(keyspace, mv[0]))
            session.execute("DROP MATERIALIZED VIEW IF EXISTS {}.{}".format(keyspace, mv[0]))
        for table in keyspace_schema['tables'].items():
            logging.debug("Dropping table {}.{}".format(keyspace, table[0]))
            session.execute("DROP TABLE IF EXISTS {}.{}".format(keyspace, table[0]))
        for udt in keyspace_schema['udt'].items():
            # Drop user-defined types next, since tables may depend on them
            session.execute("DROP TYPE IF EXISTS {}.{}".format(keyspace, udt[0]))
            # then recreate them from the backup schema
            session.execute(udt[1])
        for table in keyspace_schema['tables'].items():
            logging.debug("Creating table {}.{}".format(keyspace, table[0]))
            # Create the tables
            session.execute(table[1])
        for index in keyspace_schema['indices'].items():
            # indices were dropped with their base tables
            logging.debug("Creating index {}.{}".format(keyspace, index[0]))
            session.execute(index[1])
        for mv in keyspace_schema['materialized_views']:
            # Base tables are created now, we can create the MVs
            logging.debug("Creating MV {}.{}".format(keyspace, mv[0]))
            session.execute(mv[1])
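
A similar hedged sketch for the restore job; `cluster_backup` and `config` are assumed to be obtained from the surrounding application, and every other value below is a placeholder chosen for illustration:

import pathlib

# Hypothetical driver for RestoreJob: restores onto the cluster reachable through
# an assumed seed node, skipping interactive prompts via bypass_checks.
def run_restore(cluster_backup, config):
    job = RestoreJob(cluster_backup=cluster_backup,
                     config=config,
                     temp_dir=pathlib.Path('/tmp'),
                     host_list=None,
                     seed_target='seed1.example.com',  # hypothetical seed node
                     keep_auth=False,
                     verify=False,
                     parallel_restores=1,
                     bypass_checks=True)
    job.execute()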
class OrchestrationTest(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def setUp(self):
        self.hosts = {'127.0.0.1': ExitCode.SUCCESS}
        self.config = self._build_config_parser()
        self.medusa_config = self._build_medusa_config(self.config)
        self.orchestration = Orchestration(self.medusa_config)
        self.mock_pssh = create_autospec(ParallelSSHClient)

    def fake_ssh_client_factory(self, *args, **kwargs):
        return self.mock_pssh

    @staticmethod
    def _build_config_parser():
        """Build and return a mutable config"""

        config = configparser.ConfigParser(interpolation=None)
        config['cassandra'] = {
            'use_sudo': 'True',
        }
        config['ssh'] = {
            'username': '',
            'key_file': '',
            'port': '22',
            'cert_file': ''
        }
        return config

    @staticmethod
    def _build_medusa_config(config):
        return MedusaConfig(
            file_path=None,
            storage=None,
            monitoring={},
            cassandra=_namedtuple_from_dict(CassandraConfig,
                                            config['cassandra']),
            ssh=_namedtuple_from_dict(SSHConfig, config['ssh']),
            checks=None,
            logging=None,
            grpc=None,
            kubernetes=None,
        )

    def test_pssh_with_sudo(self):
        """Ensure that Parallel SSH honors configuration when we want to use sudo in commands"""
        output = [
            HostOutputMock(host=host, exit_code=exit_code)
            for host, exit_code in self.hosts.items()
        ]
        self.mock_pssh.run_command.return_value = output
        assert self.orchestration.pssh_run(
            list(self.hosts.keys()),
            'fake command',
            ssh_client=self.fake_ssh_client_factory)
        self.mock_pssh.run_command.assert_called_with('fake command',
                                                      host_args=None,
                                                      sudo=True)

    def test_pssh_without_sudo(self):
        """Ensure that Parallel SSH honors configuration when we don't want to use sudo in commands"""
        conf = self.config
        conf['cassandra']['use_sudo'] = 'False'
        medusa_conf = self._build_medusa_config(conf)
        orchestration_no_sudo = Orchestration(medusa_conf)

        output = [
            HostOutputMock(host=host, exit_code=exit_code)
            for host, exit_code in self.hosts.items()
        ]
        self.mock_pssh.run_command.return_value = output
        assert orchestration_no_sudo.pssh_run(
            list(self.hosts.keys()),
            'fake command',
            ssh_client=self.fake_ssh_client_factory)

        self.mock_pssh.run_command.assert_called_with('fake command',
                                                      host_args=None,
                                                      sudo=False)

    def test_pssh_run_failure(self):
        """Ensure that Parallel SSH detects a failed command on a host"""
        hosts = {
            '127.0.0.1': ExitCode.SUCCESS,
            '127.0.0.2': ExitCode.ERROR,
            '127.0.0.3': ExitCode.SUCCESS,
        }
        output = [
            HostOutputMock(host=host, exit_code=exit_code)
            for host, exit_code in hosts.items()
        ]
        self.mock_pssh.run_command.return_value = output
        assert not self.orchestration.pssh_run(
            list(self.hosts.keys()),
            'fake command',
            ssh_client=self.fake_ssh_client_factory)
        self.mock_pssh.run_command.assert_called_with('fake command',
                                                      host_args=None,
                                                      sudo=True)