Example 1
    def test_regional_cred_map(self):
        aws_clients = ClientFactory(
            regional_cred_map={'ap-east-1': {
                'profile_name': 'blah'
            }})
        self.assertEqual(aws_clients._credential_sets['ap-east-1'],
                         [None, None, None, 'blah'])
        self.assertEqual(aws_clients._credential_sets['default'],
                         [None, None, None, None])

        ClientFactory._create_client.reset_mock()
        aws_clients.get("test_service", region='ap-east-1')
        ClientFactory._create_client.assert_called_once_with(
            'ap-east-1', 'ap-east-1', 'test_service', 'default_sig_version')
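
This test only behaves as shown if ClientFactory._create_client was already replaced by a mock at class level. A minimal sketch of a setUp that would arrange that, assuming ClientFactory is importable and unittest.mock is in use (the original suite's fixture is not shown here):

import unittest
from unittest import mock

# Assumes ClientFactory is importable from the module under test.

class ClientFactoryTestCase(unittest.TestCase):
    def setUp(self):
        # Patch _create_client on the class so every instance shares one
        # MagicMock; reset_mock() and assert_called_once_with(...) then
        # behave as in the example above.
        patcher = mock.patch.object(ClientFactory, '_create_client',
                                    return_value=mock.Mock())
        self.mock_create_client = patcher.start()
        self.addCleanup(patcher.stop)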
Example 2
    def test___init__(self):
        # Mock the put_credential_set method that is called during init
        with mock.patch.object(ClientFactory,
                               'put_credential_set',
                               return_value=None):
            aws_clients = ClientFactory(None)

            msg = "clients should be empty"
            self.assertEqual({"default": {}}, aws_clients._clients, msg)

            msg = "lock should be an instance of Lock"
            self.assertEqual(type(Lock()), type(aws_clients._lock), msg)

            msg = "logger should exist and level set to ERROR"
            logger = logging.getLogger()
            logger.setLevel(logging.ERROR)
            self.assertEqual(logger, aws_clients.logger, msg)

            msg = "logger should exist and level set to WARNING"
            logger.setLevel(logging.WARNING)
            aws_clients = ClientFactory(logger)
            self.assertEqual(logger, aws_clients.logger, msg)
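
The patch above exists because, per the test's comment, __init__ calls put_credential_set during construction. A hedged variant that asserts that call directly (the assertion is illustrative, not from the original suite):

with mock.patch.object(ClientFactory, 'put_credential_set') as put_cs:
    ClientFactory(None)
    # __init__ is expected to register credentials via put_credential_set.
    assert put_cs.called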
Example 3
def client_factory_instance():
    with mock.patch.object(ClientFactory, '__init__', return_value=None):
        aws_clients = ClientFactory(None)
    aws_clients._credential_sets = {'default': [None, None, None, None]}
    aws_clients.logger = logging.getLogger()
    aws_clients._clients = {"default": {}}
    aws_clients._lock = Lock()
    return aws_clients
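
A short sketch of how a test might consume this fixture; the test body is illustrative and relies only on the attributes seeded above:

def test_default_credential_set():
    aws_clients = client_factory_instance()
    # The fixture seeds one 'default' credential set of four empty slots
    # and an empty 'default' client map.
    assert aws_clients._credential_sets['default'] == [None, None, None, None]
    assert aws_clients._clients == {"default": {}}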
Example 4
    def __init__(self,
                 path_to_templates,
                 user_config_file=None,
                 use_upstream_mappings=True,
                 client_factory=None):
        if client_factory:
            AMIUpdater.client_factory = client_factory
        else:
            AMIUpdater.client_factory = ClientFactory()
        self.all_regions = False
        if use_upstream_mappings:
            Config.load(self.upstream_config_file, configtype='Upstream')
        if user_config_file:
            Config.load(user_config_file, configtype='User')
        self._template_path = path_to_templates
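
Because the constructor stores the factory on the class, a preconfigured ClientFactory can be injected, e.g. for tests; a minimal sketch of both call styles ('templates/' is a placeholder path):

# Inject a preconfigured factory (useful when tests need a mocked one).
updater = AMIUpdater('templates/', client_factory=ClientFactory())

# Or let the constructor build its own default ClientFactory.
default_updater = AMIUpdater('templates/')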
Example 5
class CFNAlchemist(object):
    OBJECT_REWRITE_MODE = 10
    BASIC_REWRITE_MODE = 20

    def __init__(
        self,
        input_path,
        target_bucket_name,
        source_bucket_name=None,
        target_key_prefix=None,
        source_key_prefix=None,
        output_directory=None,
        rewrite_mode=OBJECT_REWRITE_MODE,
        verbose=False,
        dry_run=False
    ):
        """
        Construct an Alchemist object.

        :param input_path: Directory path to the root of the assets
        :param target_bucket_name: Target S3 bucket to use as replacement and to upload to
        :param source_bucket_name: Source S3 bucket to search for replacement
        :param target_key_prefix: Target S3 key prefix to prepend to all objects (including a trailing forward slash '/')
        :param output_directory: Directory to save rewritten assets to
        :param rewrite_mode: Rewrite mode, either CFNAlchemist.OBJECT_REWRITE_MODE or CFNAlchemist.BASIC_REWRITE_MODE
        :param verbose: Set to True to log debug messages
        :param dry_run: Set to True to perform a dry run
        """
        # create logger
        self.logger = logging.getLogger('alchemist')
        self.logger.setLevel(logging.INFO)
        # create console handler and set level to INFO
        self.ch = logging.StreamHandler()
        self.ch.setLevel(logging.INFO)
        # create formatter
        self.formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        # add formatter to ch
        self.ch.setFormatter(self.formatter)
        # add ch to logger
        self.logger.addHandler(self.ch)

        # Constants
        self._TEMPLATE_EXT = ['.template', '.json', '.yaml', '.yml']
        self._GIT_EXT = ['.git', '.gitmodules', '.gitignore', '.gitattributes']
        self._EXCLUDED_DIRS = ['.git', 'ci', '.idea', '.vs']

        # properties
        self._boto_clients = ClientFactory(logger=self.logger)
        self._auth_mode = None
        self._aws_profile = None
        self._aws_access_key_id = None
        self._aws_secret_access_key = None
        self._aws_session_token = None

        # properties with setters/getters
        self._input_path = None
        self._target_bucket_name = None
        self._target_key_prefix = None
        self._output_directory = None
        self._rewrite_mode = self.OBJECT_REWRITE_MODE
        self._excluded_prefixes = None
        self._verbose = False
        self._dry_run = False
        self._prod_bucket_name = 'aws-quickstart'
        self._prod_key_prefix = None
        self._default_region = 'us-east-1'
        self._file_list = None

        # initialize
        self.set_input_path(input_path)
        self.set_prod_bucket_name(source_bucket_name)
        self.set_target_bucket_name(target_bucket_name)
        self.set_target_key_prefix(target_key_prefix)
        self.set_output_directory(output_directory)
        if rewrite_mode not in [self.OBJECT_REWRITE_MODE, self.BASIC_REWRITE_MODE]:
            self.logger.error("Invalid rewrite_mode.")
        else:
            self.set_rewrite_mode(rewrite_mode)
        self.set_verbose(verbose)
        self.set_dry_run(dry_run)
        self.set_prod_key_prefix(source_key_prefix)


    def set_verbose(self, verbose):
        self._verbose = verbose
        self.logger.setLevel(logging.DEBUG if self._verbose else logging.INFO)
        self.ch.setLevel(logging.DEBUG if self._verbose else logging.INFO)

    def get_verbose(self):
        return self._verbose

    def set_dry_run(self, dry_run):
        self._dry_run = dry_run

    def get_dry_run(self):
        return self._dry_run

    def set_input_path(self, input_path):
        self._input_path = input_path

    def get_input_path(self):
        return self._input_path

    def set_prod_key_prefix(self, source_key_prefix):
        if source_key_prefix is not None:
            self._prod_key_prefix = source_key_prefix.strip('/') + '/'

    def get_prod_key_prefix(self):
        return self._prod_key_prefix

    def set_target_bucket_name(self, target_bucket_name):
        self._target_bucket_name = target_bucket_name

    def get_target_bucket_name(self):
        return self._target_bucket_name

    def set_target_key_prefix(self, target_key_prefix):
        if target_key_prefix is not None:
            self._target_key_prefix = target_key_prefix.strip('/') + '/'
            self._set_excluded_key_prefixes()

    def get_target_key_prefix(self):
        return self._target_key_prefix

    def set_output_directory(self, output_directory):
        self._output_directory = output_directory

    def get_output_directory(self):
        return self._output_directory

    def set_rewrite_mode(self, rewrite_type):
        self._rewrite_mode = rewrite_type

    def get_rewrite_mode(self):
        return self._rewrite_mode

    def set_prod_bucket_name(self, prod_bucket_name):
        if prod_bucket_name is not None:
            self._prod_bucket_name = prod_bucket_name

    def get_prod_bucket_name(self):
        return self._prod_bucket_name

    def set_default_region(self, region):
        self._default_region = region

    def get_default_region(self):
        return self._default_region

    def _set_excluded_key_prefixes(self):
        self._excluded_prefixes = [
            '{}doc/'.format(self._target_key_prefix),
            '{}pics/'.format(self._target_key_prefix),
            '{}media/'.format(self._target_key_prefix),
            '{}downloads/'.format(self._target_key_prefix),
            '{}installers/'.format(self._target_key_prefix)
        ]

    def _get_excluded_key_prefixes(self):
        return self._excluded_prefixes

    def upload_only(self):
        """
        This function uploads all assets to the target S3 bucket name using the target S3 key prefix for each
          object. A comparison of checksums is done for all objects as well to avoid re-uploading files that have
          not changed (this checksum comparison is only effective for objects that were not multipart-uploaded).
        """
        if self._target_key_prefix is None:
            raise TaskCatException('target_key_prefix cannot be None')
        # TODO: FIGURE OUT BOTO SESSION HANDLING DETAILS CURRENTLY USING ClientFactory's get_session from utils.py
        '''
        # Use a profile
        if args.profile:
            boto_session = boto3.Session(profile_name=args.profile)
            s3_resource = boto_session.resource('s3')
        # Use explicit credentials
        elif args.access_key_id and args.secret_access_key:
            boto_session = boto3.Session(aws_access_key_id=args.access_key_id,
                                         aws_secret_access_key=args.secret_access_key)
            s3_resource = boto_session.resource('s3')
        # Attempt to use IAM role from instance profile
        else:
            boto_session = boto3.Session()
            s3_resource = boto_session.resource('s3')
        '''
        boto_session = self._boto_clients.get_session(
            credential_set='alchemist',
            region=self.get_default_region()
        )
        s3_resource = boto_session.resource('s3')
        upload_bucket = s3_resource.Bucket(self._target_bucket_name)

        self.logger.info("Gathering remote S3 bucket keys {}*".format(self._target_key_prefix))
        remote_key_dict = {}
        for obj in upload_bucket.objects.filter(Prefix='{}'.format(self._target_key_prefix)):
            # skip keys under the excluded prefixes (mirrors the delete guard below)
            if not any(x in obj.key for x in self._get_excluded_key_prefixes()):
                remote_key_dict[obj.key] = obj
        self.logger.debug(remote_key_dict.keys())

        # Gather file list
        # NOTE: We only use the output directory if it has been set (that is, a rewrite was expected to have
        #       happened into an output directory). We ensure that is not the case when parsing the args, but
        #       care must be taken when initializing the properties of this class. If only an upload is meant to
        #       happen, without a previous rewrite, then the output directory should never be set.
        self.logger.info("Gathering local keys {}*".format(self._target_key_prefix))
        if self._file_list:
            file_list = self._file_list
        else:
            file_list = self._get_file_list(self._input_path)

        local_key_dict = {}
        for current_file in file_list:
            relative_path = current_file.replace(self._input_path, '', 1).lstrip('\\/')
            source_dir = self._output_directory if self._output_directory and not self._dry_run else self._input_path
            local_key_dict[os.path.join(self._target_key_prefix, relative_path).replace('\\', '/')] = \
                os.path.join(source_dir, relative_path)
        self.logger.debug(local_key_dict.keys())

        remote_to_local_diff = list(set(remote_key_dict.keys()) - set(local_key_dict.keys()))
        self.logger.info("Keys in remote S3 bucket but not in local:")
        self.logger.info(remote_to_local_diff)

        local_to_remote_diff = list(set(local_key_dict.keys()) - set(remote_key_dict.keys()))
        self.logger.info("Keys in local but not in remote S3 bucket:")
        self.logger.info(local_to_remote_diff)

        self.logger.info("Syncing objects to S3 bucket [{}]".format(self._target_bucket_name))
        for _key in local_key_dict.keys():
            if _key in remote_key_dict:
                self.logger.debug("File [{0}] exists in S3 bucket [{1}]. Verifying MD5 checksum for difference.".format(_key, self._target_bucket_name))
                s3_hash = remote_key_dict[_key].e_tag.strip('"')
                with open(local_key_dict[_key], 'rb') as local_file:
                    local_hash = hashlib.md5(local_file.read()).hexdigest()
                self.logger.debug("S3 MD5 checksum (etag) [{0}]=>[{1}]".format(s3_hash, remote_key_dict[_key]))
                self.logger.debug("Local MD5 checksum     [{0}]=>[{1}]".format(local_hash, local_key_dict[_key]))
                if s3_hash != local_hash:
                    if self._dry_run:
                        self.logger.info("[WHAT IF DRY RUN]: UPDATE [{0}]".format(_key))
                    else:
                        self.logger.info("UPDATE [{0}]".format(_key))
                        s3_resource.Object(self._target_bucket_name, _key).upload_file(local_key_dict[_key])
                else:
                    self.logger.debug("MD5 checksums are the same. Skipping [{}]".format(_key))
            else:
                if self._dry_run:
                    self.logger.info("[WHAT IF DRY RUN]: CREATE [{0}]".format(_key))
                else:
                    self.logger.info("CREATE [{0}]".format(_key))
                    # Upload local file not present in S3 bucket
                    s3_resource.Object(self._target_bucket_name, _key).upload_file(local_key_dict[_key])

        # clean up/remove remote keys that are not in local keys
        for _key in remote_to_local_diff:
            if not any(x in _key for x in self._get_excluded_key_prefixes()):
                if self._dry_run:
                    self.logger.info("[WHAT IF DRY RUN]: DELETE [{0}]".format(_key))
                else:
                    self.logger.info("DELETE [{0}]".format(_key))
                    remote_key_dict[_key].delete()

    def rewrite_only(self):
        """
        This function searches through all the files and rewrites any references of the production S3 bucket name
         to the target S3 bucket name. This is done by both things like line-by-line basic rewrites or walking the
         tree of a JSON or YAML document to find the references.
        """
        # Create file list and recurse if args._input_path is directory
        file_list = self._get_file_list(self._input_path)
        self.logger.info("Files to be worked on:")
        self.logger.info(file_list)

        # Validate output
        if self._output_directory is not None:
            CFNYAMLHandler.validate_output_dir(self._output_directory)

        self.logger.info("Production S3 bucket name that we are looking for [{}]".format(self._prod_bucket_name))
        self.logger.info("Replacement S3 bucket name that we are rewriting with [{}]".format(self._target_bucket_name))
        self.logger.info("Production S3 key prefix that we are looking for [{}]".format(self._prod_key_prefix))
        self.logger.info("Replacement S3 key prefix that we are rewriting with [{}]".format(self._target_key_prefix))

        # Rewrite files
        for current_file in file_list:
            # Determine output file
            if self._output_directory:
                if len(file_list) == 1:
                    output_file = os.path.join(self._output_directory, os.path.basename(current_file))
                else:
                    output_file = os.path.join(self._output_directory, current_file.replace(self._input_path, '', 1).lstrip('\/'))
            else:
                output_file = current_file

            # Load current file
            if self._rewrite_mode != self.BASIC_REWRITE_MODE \
                    and current_file.endswith(tuple(self._TEMPLATE_EXT)) \
                    and os.path.dirname(current_file).endswith('/templates'):
                self.logger.info("Opening file [{}]".format(current_file))
                with open(current_file, 'r', newline=None) as template:
                    template_raw_data = template.read()
                template_raw_data = template_raw_data.strip()

                if template_raw_data[0] in ['{', '['] and template_raw_data[-1] in ['}', ']']:
                    self.logger.info('Detected JSON. Loading file.')
                    FILE_FORMAT = 'JSON'
                    with open(current_file, 'r', newline=None) as json_file:
                        template_data = json.load(json_file, object_pairs_hook=OrderedDict)
                else:
                    self.logger.info('Detected YAML. Loading file.')
                    FILE_FORMAT = 'YAML'
                    with open(current_file, 'r', newline=None) as yaml_file:
                        template_data = CFNYAMLHandler.ordered_safe_load(yaml_file, object_pairs_hook=OrderedDict)

                if FILE_FORMAT in ['JSON', 'YAML']:
                    # Iterate through every top level node.
                    # This was only added in case we need to examine only parts of the template
                    if type(template_data) in [OrderedDict, dict]:
                        for node_key in template_data.keys():
                            self.logger.debug("Working on node [{}]".format(node_key))
                            self._recurse_nodes(template_data[node_key])
                    elif type(template_data) is list:
                        self._recurse_nodes(template_data)
                    else:
                        if self._dry_run:
                            self.logger.warning("[WHAT IF DRY RUN]: [{0}] Unsupported {1} structure. Skipping but copying.".format(current_file, FILE_FORMAT))
                        else:
                            self.logger.warning("[{0}] Unsupported {1} structure. Skipping but copying.".format(current_file, FILE_FORMAT))
                            if current_file != output_file:
                                shutil.copyfile(current_file, output_file)

                    # Write modified template
                    if self._dry_run:
                        self.logger.info("[WHAT IF DRY RUN]: Writing file [{}]".format(output_file))
                    else:
                        self.logger.info("Writing file [{}]".format(output_file))
                        CFNYAMLHandler.validate_output_dir(os.path.split(output_file)[0])
                        with open(output_file, 'w') as updated_template:
                            if FILE_FORMAT == 'JSON':
                                updated_template.write(json.dumps(template_data, indent=4, separators=(',', ': ')))
                            elif FILE_FORMAT == 'YAML':
                                updated_template.write(
                                    CFNYAMLHandler.ordered_safe_dump(template_data, indent=2, allow_unicode=True, default_flow_style=False, explicit_start=True, explicit_end=True))
                else:
                    if self._dry_run:
                        self.logger.warning("[WHAT IF DRY RUN]: [{}] Unsupported file format. Skipping but copying.".format(current_file))
                    else:
                        self.logger.warning("[{}] Unsupported file format. Skipping but copying.".format(current_file))
                        if current_file != output_file:
                            shutil.copyfile(current_file, output_file)
            else:
                self.logger.info("Opening file [{}]".format(current_file))
                try:
                    with open(current_file, 'r', newline=None) as f:
                        file_data = f.readlines()

                    for index, line in enumerate(file_data):
                        file_data[index] = self._string_rewriter(line)

                    # Write modified file
                    if self._dry_run:
                        self.logger.info("[WHAT IF DRY RUN]: Writing file [{}]".format(output_file))
                    else:
                        self.logger.info("Writing file [{}]".format(output_file))
                        CFNYAMLHandler.validate_output_dir(os.path.split(output_file)[0])
                        with open(output_file, 'w') as updated_file:
                            updated_file.writelines(file_data)
                except UnicodeDecodeError:
                    if self._dry_run:
                        self.logger.info("[WHAT IF DRY RUN]: Ran into a (UnicodeDecodeError) problem trying to read the file [{}]. Skipping but copying.".format(current_file))
                    else:
                        self.logger.warning("Ran into a (UnicodeDecodeError) problem trying to read the file [{}]. Skipping but copying.".format(current_file))
                        self._copy_file(current_file, output_file)
                except TaskCatException:
                    raise
                except Exception as e:
                    raise e

    def rewrite_and_upload(self):
        """
        This function performs both a rewrite and upload of files by calling each respective function consecutively.
        """
        self.rewrite_only()
        self.upload_only()

    def _get_file_list(self, input_path):
        if not self._file_list:
            _file_list = []
            if os.path.isfile(input_path):
                _file_list.append(input_path)
            elif os.path.isdir(input_path):
                for root, dirs, files in os.walk(input_path):
                    for _current_file in files:
                        if not _current_file.endswith(tuple(self._GIT_EXT)):
                            _file_list.append(os.path.join(root, _current_file))
                    for directory in self._EXCLUDED_DIRS:
                        if directory in dirs:
                            dirs.remove(directory)
            else:
                raise TaskCatException("Directory/File is non-existent. Aborting.")
            self._file_list = _file_list
        return self._file_list

    def _string_rewriter(self, current_string):
        if self._prod_bucket_name in current_string:
            # If the path is s3/http/https
            if any(x in current_string for x in ['s3:', 'http:', 'https:']):
                # Make sure that it's part of the target key prefix (that is, part of this repo)
                if self._target_key_prefix in current_string:
                    self.logger.info("Rewriting [{}]".format(current_string.rstrip('\n\r')))
                    return current_string.replace(self._prod_bucket_name, self._target_bucket_name)
                # If it's not then, it's a reference that should not be touched
                else:
                    self.logger.info("NOT rewriting [{}] because it's not part of this repo".format(current_string.rstrip('\n\r')))
                    return current_string
            # Else just replace the bucket name
            else:
                self.logger.info("Rewriting [{}]".format(current_string.rstrip('\n\r')))
                return current_string.replace(self._prod_bucket_name, self._target_bucket_name)
        elif self._prod_key_prefix and self._prod_key_prefix in current_string:
            self.logger.info("Rewriting [{}]".format(current_string.rstrip('\n\r')))
            return current_string.replace(self._prod_key_prefix, self._target_key_prefix)
        else:
            return current_string

    def _recurse_nodes(self, current_node):
        if type(current_node) in [OrderedDict, dict]:
            for key in current_node.keys():
                self.logger.debug("Key: ")
                self.logger.debug(key)
                self.logger.debug("Type: ")
                self.logger.debug(type(current_node[key]))
                self.logger.debug("Value: ")
                self.logger.debug(current_node[key])
                current_node[key] = self._recurse_nodes(current_node[key])
        elif type(current_node) is list:
            for _index, item in enumerate(current_node):
                self.logger.debug("Type: ")
                self.logger.debug(type(item))
                self.logger.debug("Value: ")
                self.logger.debug(item)
                current_node[_index] = self._recurse_nodes(item)
            return current_node
        elif type(current_node) is str:
            return self._string_rewriter(current_node)
        elif type(current_node) is bool:
            self.logger.debug("Not much we can do with booleans. Skipping.")
        elif type(current_node) in [int, float]:
            self.logger.debug("Not much we can do with numbers. Skipping.")
        elif type(current_node) in [datetime.date, datetime.time, datetime.datetime, datetime.timedelta]:
            self.logger.debug("Not much we can do with datetime. Skipping.")
        elif current_node is None:
            self.logger.debug("Not much we can do with nulls. Skipping.")
        else:
            self.logger.error("Unsupported type.")
            self.logger.error("Failing Type: ")
            self.logger.error(type(current_node))
            self.logger.error("Failing Value: ")
            self.logger.error(current_node)
            raise TaskCatException("Unsupported type.")

        self.logger.debug("PARSED!")

        return current_node

    def _copy_file(self, in_file, out_file):
        CFNYAMLHandler.validate_output_dir(os.path.split(out_file)[0])
        # copy only if it's a new location for the output
        if in_file != out_file:
            shutil.copyfile(in_file, out_file)

    def aws_api_init(self, aws_profile=None, aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None):
        """
        This function reads the AWS credentials to ensure that the client has the right credentials defined to
        successfully authenticate against an AWS account. These can be a profile name, an access key/secret key
        pair (optionally with a session token), or none (falling back to the environment).
        :param aws_profile: AWS profile name.
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        """
        if aws_profile is not None:
            if not (aws_secret_access_key is None and aws_access_key_id is None):
                self.logger.error("Cannot use aws_profile with aws_access_key_id or aws_secret_access_key")

        if aws_profile:
            self._auth_mode = 'profile'
            self._aws_profile = aws_profile
        elif aws_access_key_id and aws_secret_access_key:
            self._auth_mode = 'keys'
            self._aws_access_key_id = aws_access_key_id
            self._aws_secret_access_key = aws_secret_access_key
            self._aws_session_token = aws_session_token
        else:
            self._auth_mode = 'environment'
        try:
            sts_client = self._boto_clients.get(
                'sts',
                credential_set='alchemist',
                aws_access_key_id=self._aws_access_key_id,
                aws_secret_access_key=self._aws_secret_access_key,
                aws_session_token=self._aws_session_token,
                profile_name=self._aws_profile,
                region=self.get_default_region()
            )
            account = sts_client.get_caller_identity().get('Account')
        except TaskCatException:
            raise
        except Exception as e:
            try:
                self.logger.warning('Trying GovCloud region.')
                self.set_default_region('us-gov-west-1')
                sts_client = self._boto_clients.get(
                    'sts',
                    credential_set='alchemist',
                    aws_access_key_id=self._aws_access_key_id,
                    aws_secret_access_key=self._aws_secret_access_key,
                    aws_session_token=self._aws_session_token,
                    profile_name=self._aws_profile,
                    region=self.get_default_region()
                )
                account = sts_client.get_caller_identity().get('Account')
            except TaskCatException:
                raise
            except Exception as e:
                self.logger.error("Credential Error - Please check you {}!".format(self._auth_mode))
                self.logger.debug(str(e))
                raise TaskCatException("Credential Error - Please check you {}!".format(self._auth_mode))
        self.logger.info("AWS AccountNumber: \t [%s]" % account)
        self.logger.info("Authenticated via: \t [%s]" % self._auth_mode)

    @staticmethod
    def interface():
        """
        This function creates an argparse parser, parses the arguments, and returns an args object.

        :return: An object from argparse which contains all the args passed in from the command line.
        """
        # Creating Parser
        parser = argparse.ArgumentParser(
            prog="alchemist",
            description="AWS Quick Start rewriter and uploader of assets."
        )
        parser.add_argument(
            "input_path",
            type=str,
            help="the input path of assets to rewrite and/or upload."
        )
        parser.add_argument(
            "-sb",
            "--source-bucket-name",
            type=str,
            help="source S3 bucket name for rewrite."
        )
        parser.add_argument(
            "target_bucket_name",
            type=str,
            help="target S3 bucket name for rewrite and/or upload."
        )
        parser.add_argument(
            "-t",
            "--target-key-prefix",
            type=str,
            help="target S3 key prefix to use. This is required when uploading."
        )
        parser.add_argument(
            "-sp",
            "--source-key-prefix",
            type=str,
            help="source S3 key prefix name for rewrite."
        )
        parser.add_argument(
            "-o",
            "--output-directory",
            type=str,
            help="custom output directory path. If no path is specified, will overwrite current file(s)."
        )
        parser.add_argument(
            "-b",
            "--basic-rewrite",
            action='store_true',
            help="specify to perform a basic rewrite vs. walking the document."
        )
        actions = parser.add_mutually_exclusive_group(required=True)
        actions.add_argument(
            "-u",
            "--upload-only",
            action='store_true',
            help="specify to only upload to S3 (no rewrite)."
        )
        actions.add_argument(
            "-r",
            "--rewrite-only",
            action='store_true',
            help="specify to only rewrite (no upload)."
        )
        actions.add_argument(
            "-ru",
            "--rewrite-and-upload",
            action='store_true',
            help="specify to rewrite and upload to S3."
        )
        parser.add_argument(
            "--convert-key-prefix-to-slashes",
            action='store_true',
            help="specify to convert a quickstart-some-repo/ key prefix to a some/repo/latest/ key prefix."
        )
        parser.add_argument(
            "-p",
            "--aws-profile",
            type=str,
            help="use existing AWS credentials profile."
        )
        parser.add_argument(
            "-a",
            "--aws-access-key-id",
            type=str,
            help="AWS access key ID."
        )
        parser.add_argument(
            "-s",
            "--aws-secret-access-key",
            type=str,
            help="AWS secret access key."
        )
        parser.add_argument(
            "-st",
            "--aws-session-token",
            type=str,
            help="AWS secret access key."
        )
        parser.add_argument(
            "--verbose",
            action='store_true',
            help="specify to enable debug mode logging."
        )
        parser.add_argument(
            "--dry-run",
            action='store_true',
            help="specify to simulate the rewrite and upload actions to learn what would happen."
        )

        args = parser.parse_args()

        if args.aws_profile is not None:
            if not (args.aws_secret_access_key is None and args.aws_access_key_id is None):
                parser.error("Cannot use -p/--aws-profile with -a/--aws-access-key-id or -s/--aws-secret-access-key")

        if args.upload_only and args.output_directory:
            parser.error("Upload only mode does not use an output directory")

        if args.upload_only or args.rewrite_and_upload:
            if args.target_key_prefix is None:
                parser.error("-t/--target-key-prefix must be provided when uploading is specified (-u/--upload-only or -ru/--rewrite-and-upload")

        if args.convert_key_prefix_to_slashes:
            args.target_key_prefix = CFNAlchemist.aws_quickstart_s3_key_prefix_builder(args.target_key_prefix)

        return args

    @staticmethod
    def aws_quickstart_s3_key_prefix_builder(repo_name):
        """
        This converts a quickstart-some-repo/ key prefix string to a some/repo/latest/ key prefix string.

        :return: The converted S3 key prefix string.
        """
        # Determine S3 path from a valid git repo name
        # Remove quickstart-, change dashes to slashes, and add /latest
        repo_path = os.path.join(repo_name.replace('quickstart-', '', 1).replace('-', '/'), 'latest/')

        # EXCEPTIONS (that we have to live with for now):
        # enterprise-accelerator
        repo_path = repo_path.replace('enterprise/accelerator', 'enterprise-accelerator', 1)
        # nist-high
        repo_path = repo_path.replace('nist/high', 'nist-high', 1)
        # chef-server
        repo_path = repo_path.replace('chefserver', 'chef-server', 1)
        print("[INFO]: Converted repo name [" + str(repo_name) + "] to S3 path [" + str(repo_path) + "]")
        return repo_path
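
Putting the class together, a hedged end-to-end sketch of driving CFNAlchemist programmatically instead of through interface(); the bucket names, key prefix, and profile are placeholders:

alchemist = CFNAlchemist(
    input_path='quickstart-some-repo/',        # placeholder asset root
    target_bucket_name='my-dev-bucket',        # placeholder bucket
    target_key_prefix='some/repo/latest/',     # required when uploading
    dry_run=True                               # only log what would happen
)
alchemist.aws_api_init(aws_profile='default')  # validate credentials via STS
alchemist.rewrite_and_upload()                 # rewrite_only() then upload_only()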
Example 6
def client_factory_instance():
    with mock.patch.object(ClientFactory, '__init__', return_value=None):
        aws_clients = ClientFactory(None)
    aws_clients._credential_sets = {'default': [None, None, None, None]}
    return aws_clients