def trigger_state_machine(self):
        """Start the state machine named by ``STATE_MACHINE_ARN`` for ``self.event``.

        The execution name is ``event-from-<resource>-<timestamp>``, where
        ``<resource>`` is ``'<account>-<detail.resource-type>-tagged'`` when the
        event carries ``detail.resource-type`` and ``'stno-console'`` otherwise.
        The ARN is also written into ``self.event`` under ``'StateMachineArn'``
        before the state machine is triggered.

        Raises:
            ValueError: if ``STATE_MACHINE_ARN`` is unset or empty (checked before
                ``self.event`` is modified).
            Exception: any other failure is logged with file/class/method
                context and re-raised unchanged.
        """
        try:
            self.logger.info("Executing: " + self.__class__.__name__ + "/" +
                             inspect.stack()[0][3])
            # Validate first: previously a missing ARN was written into
            # self.event as None and the method then died on None.split(...).
            state_machine_arn = environ.get('STATE_MACHINE_ARN')
            if not state_machine_arn:
                raise ValueError(
                    "Environment variable STATE_MACHINE_ARN is not set")

            account_id = self.event.get('account')
            tagged_type = self.event.get('detail', {}).get('resource-type')
            if tagged_type is None:
                resource_type = 'stno-console'
            else:
                resource_type = account_id + '-' + tagged_type + '-tagged'

            # Execute State Machine
            # NOTE: '%s' (epoch seconds) is a glibc strftime extension, not portable.
            exec_name = "%s-%s-%s" % ('event-from', resource_type,
                                      time.strftime("%Y-%m-%dT%H-%M-%S-%s"))
            self.event.update({'StateMachineArn': state_machine_arn})

            sm = StateMachine(self.logger)
            # arn:aws:states:<region>:<account>:stateMachine:<name> -> <name>
            self.logger.info("Triggering {} State Machine".format(
                state_machine_arn.split(":", 6)[6]))
            response = sm.trigger_state_machine(state_machine_arn, self.event,
                                                sanitize(exec_name))
            self.logger.info(
                "State machine triggered successfully, Execution Arn: {}".
                format(response))
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'CLASS': self.__class__.__name__,
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise
예제 #2
0
    def trigger_state_machine(self):
        """Start the state machine that serves this CloudFormation custom resource.

        ``self.event['ResourceType']`` selects the environment variable holding
        the target state machine ARN. ``self.event['RequestType']`` and selected
        ``ResourceProperties`` fields (trimmed to 45 characters) form the
        execution name. The ARN is also written into ``self.event`` under
        ``'StateMachineArn'``.

        Raises:
            Exception: if the resource type is unsupported or its environment
                variable is unset/empty; any other failure is logged with
                file/class/method context and re-raised unchanged.
        """
        # Resource type -> environment variable holding the state machine ARN.
        # This is the single source of truth for the supported types.
        env_var_by_resource_type = {
            'Custom::Organizations': 'sm_arn_account',
            'Custom::ServiceControlPolicy': 'sm_arn_service_control_policy',
            'Custom::StackInstance': 'sm_arn_stack_set',
            'Custom::CheckAVMExistsForAccount': 'sm_arn_check_avm_exists',
            'Custom::ADConnector': 'sm_arn_ad_connector',
            'Custom::HandShakeStateMachine': 'sm_arn_handshake_sm',
        }
        # Resource type -> builder of the execution-name label from
        # ResourceProperties. Types missing here use their bare type name
        # plus an epoch-seconds timestamp suffix instead.
        label_builders = {
            'Custom::StackInstance': lambda props: props.get('StackSetName'),
            'Custom::Organizations': lambda props: (
                props.get('OUName') + '-' + props.get('AccountName')),
            'Custom::ServiceControlPolicy': lambda props: props.get('Operation'),
            'Custom::HandShakeStateMachine': lambda props: (
                props.get('ServiceType') + '-' + props.get('SpokeRegion')
                + '-' + props.get('SpokeAccountId')),
            'Custom::CheckAVMExistsForAccount': lambda props: (
                props.get('ProdParams', {}).get('OUName') + '-'
                + props.get('ProdParams', {}).get('AccountName')),
        }
        try:
            self.logger.info("Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3])
            resource_type = self.event.get('ResourceType')
            request_type = self.event.get('RequestType')

            env_var = env_var_by_resource_type.get(resource_type)
            state_machine_arn = environ.get(env_var) if env_var else None
            if not state_machine_arn:
                self.logger.error("ResourceType Not Supported {} or Env. Variable not found".format(resource_type))
                raise Exception("ResourceType Not Supported {} or Env. Variable not found".format(resource_type))

            sm = StateMachine(self.logger)

            # Execute State Machine
            build_label = label_builders.get(resource_type)
            if build_label is not None:
                label = trim_length(build_label(self.event.get('ResourceProperties', {})), 45)
                timestamp = time.strftime("%Y-%m-%dT%H-%M-%S")
            else:
                label = resource_type.replace("Custom::", "")
                # NOTE: '%s' (epoch seconds) is a glibc strftime extension.
                timestamp = time.strftime("%Y-%m-%dT%H-%M-%S-%s")
            exec_name = "%s-%s-%s-%s" % ('AVM-CR', request_type, label, timestamp)

            self.event.update({'StateMachineArn': state_machine_arn})
            # arn:aws:states:<region>:<account>:stateMachine:<name> -> <name>
            self.logger.info("Triggering {} State Machine".format(state_machine_arn.split(":", 6)[6]))
            response = sm.trigger_state_machine(state_machine_arn, self.event, sanitize(exec_name))
            self.logger.info("State machine triggered successfully, Execution Arn: {}".format(response))
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'CLASS': self.__class__.__name__,
                       'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise
예제 #3
0
 def __init__(self, logger, wait_time, manifest_file_path, sm_arn_launch_avm):
     """Create the helper clients and store launch-AVM settings.

     :param logger: logger handed to every helper client and kept on ``self``.
     :param wait_time: stored unchanged as ``self.wait_time``.
     :param manifest_file_path: manifest path; ``manifest_folder`` is this path
         with its last ``len(MANIFEST_FILE_NAME)`` characters removed.
     :param sm_arn_launch_avm: launch-AVM state machine ARN, stored unchanged.
     """
     # Helper clients, constructed in the same order as before.
     self.state_machine = StateMachine(logger)
     self.ssm = SSM(logger)
     self.param_handler = ParamsHandler(logger)

     self.logger = logger
     self.manifest_file_path = manifest_file_path
     # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
     name_length = len(MANIFEST_FILE_NAME)
     self.manifest_folder = manifest_file_path[:-name_length]
     self.wait_time = wait_time
     self.sm_arn_launch_avm = sm_arn_launch_avm

     # Start empty; nothing in this constructor fills these.
     self.manifest = None
     self.list_sm_exec_arns = []
예제 #4
0
 def __init__(self, logger, sm_input_list):
     """Store the state machine inputs and build the helper clients.

     :param logger: logger kept on ``self`` and passed to each helper client.
     :param sm_input_list: state machine inputs, stored unchanged.

     ``WAIT_TIME`` and ``EXECUTION_MODE`` are read from the environment as raw
     strings (``None`` when unset).
     """
     self.logger = logger
     self.sm_input_list = sm_input_list
     self.list_sm_exec_arns = []
     self.stack_set_exist = True

     # Helper clients, constructed in the original order.
     self.solution_metrics = Metrics(logger)
     self.param_handler = CFNParamsHandler(logger)
     self.state_machine = StateMachine(logger)
     self.stack_set = StackSet(logger)
     self.s3 = S3(logger)

     env = os.environ
     self.wait_time = env.get('WAIT_TIME')
     self.execution_mode = env.get('EXECUTION_MODE')
예제 #5
0
 def __init__(self, logger, wait_time, manifest_file_path, sm_arn_launch_avm, batch_size):
     """Create the helper clients and store batch launch-AVM settings.

     :param logger: logger handed to every helper client and kept on ``self``.
     :param wait_time: stored unchanged as ``self.wait_time``.
     :param manifest_file_path: manifest path; ``manifest_folder`` drops its last
         ``len(MANIFEST_FILE_NAME)`` characters.
     :param sm_arn_launch_avm: launch-AVM state machine ARN, stored unchanged.
     :param batch_size: stored unchanged as ``self.batch_size``.
     """
     # Helper clients, constructed in the original order.
     self.state_machine = StateMachine(logger)
     self.ssm = SSM(logger)
     self.param_handler = ParamsHandler(logger)

     self.logger = logger
     self.manifest_file_path = manifest_file_path
     # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
     trailing = len(MANIFEST_FILE_NAME)
     self.manifest_folder = manifest_file_path[:-trailing]
     self.wait_time = wait_time
     self.sm_arn_launch_avm = sm_arn_launch_avm
     self.batch_size = batch_size

     # Placeholders; none of them is populated by this constructor.
     self.manifest = None
     self.list_sm_exec_arns = []
     self.avm_product_name = None
     self.avm_portfolio_name = None
     self.avm_params = None
     self.root_id = None
 def __init__(self, logger, wait_time, manifest_file_path, sm_arn_scp,
              staging_bucket):
     """Create the helper clients and store service-control-policy settings.

     :param logger: logger handed to every helper client and kept on ``self``.
     :param wait_time: stored unchanged as ``self.wait_time``.
     :param manifest_file_path: manifest path; ``manifest_folder`` drops its last
         ``len(MANIFEST_FILE_NAME)`` characters.
     :param sm_arn_scp: SCP state machine ARN, stored unchanged.
     :param staging_bucket: bucket name, stored unchanged.
     """
     # Helper clients, constructed in the original order.
     self.state_machine = StateMachine(logger)
     self.s3 = S3(logger)
     self.send = Metrics(logger)
     self.param_handler = ParamsHandler(logger)

     self.logger = logger
     self.manifest_file_path = manifest_file_path
     # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
     suffix_len = len(MANIFEST_FILE_NAME)
     self.manifest_folder = manifest_file_path[:-suffix_len]
     self.wait_time = wait_time
     self.sm_arn_scp = sm_arn_scp
     self.staging_bucket = staging_bucket

     # Initial values; filled in elsewhere, if at all.
     self.manifest = None
     self.list_sm_exec_arns = []
     self.nested_ou_delimiter = ""
     self.root_id = None
예제 #7
0
파일: ibs_agi.py 프로젝트: sankopay/IBSng
def initStateMachine():
    """Create the global "MAIN" state machine, then initialise state plugins.

    Rebinds the module-level ``state_machine`` and passes the config value
    ``state_plugins`` to ``PluginLoader.initPlugins``.
    """
    global state_machine
    # Function-local imports, kept in the original order: the state machine
    # exists before lib.plugin_loader is imported.
    from lib.state_machine import StateMachine
    state_machine = StateMachine("MAIN")

    from lib.plugin_loader import PluginLoader
    loader = PluginLoader()
    plugin_names = getConfig().getValue("state_plugins")
    loader.initPlugins(plugin_names)
 def __init__(self, logger, sm_arns_map, staging_bucket, manifest_file_path,
              pipeline_stage, token, execution_mode, primary_account_id):
     """Create helper clients and store the pipeline-trigger settings.

     :param logger: logger handed to every helper client and kept on ``self``.
     :param sm_arns_map: state machine ARN mapping, stored unchanged.
     :param staging_bucket: bucket name, stored unchanged.
     :param manifest_file_path: manifest path; ``manifest_folder`` drops its last
         ``len(MANIFEST_FILE_NAME)`` characters.
     :param pipeline_stage: stored unchanged.
     :param token: stored unchanged.
     :param execution_mode: ``'sequential'`` (case-insensitive) sets
         ``isSequential``; any other string clears it.
     :param primary_account_id: stored unchanged.
     """
     # Helper clients, constructed in the original order.
     self.state_machine = StateMachine(logger)
     self.ssm = SSM(logger)
     self.s3 = S3(logger)
     self.send = Metrics(logger)
     self.param_handler = ParamsHandler(logger)

     self.logger = logger
     self.sm_arns_map = sm_arns_map
     self.manifest = None
     self.staging_bucket = staging_bucket
     self.manifest_file_path = manifest_file_path
     self.token = token
     self.pipeline_stage = pipeline_stage
     # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
     self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]

     self.isSequential = execution_mode.lower() == 'sequential'
     self.index = 100  # starting index value, stored as-is
     self.primary_account_id = primary_account_id
예제 #9
0
 def test_next(self):
     """Calling next() walks the six-state cycle in order and counts each step."""
     transitions = {
         "1": "1.3",
         "1.3": "1.6",
         "1.6": "2",
         "2": "2.3",
         "2.3": "2.6",
         "2.6": "1"
     }
     machine = StateMachine(transitions, "1", False)
     expected_order = ["1", "1.3", "1.6", "2", "2.3", "2.6"]
     for expected_steps, expected_step in enumerate(expected_order):
         # Current state and step counter before advancing.
         self.assertEqual(machine.step, expected_step)
         self.assertEqual(machine.steps, expected_steps)
         machine.next()
 def __init__(self, logger, wait_time, manifest_file_path, sm_arn_stackset, staging_bucket, execution_mode):
     """Create helper clients and store stack-set deployment settings.

     :param logger: logger handed to every helper client and kept on ``self``.
     :param wait_time: stored unchanged.
     :param manifest_file_path: manifest path; ``manifest_folder`` drops its last
         ``len(MANIFEST_FILE_NAME)`` characters.
     :param sm_arn_stackset: stack-set state machine ARN, stored unchanged.
     :param staging_bucket: bucket name, stored unchanged.
     :param execution_mode: ``'sequential'`` (case-insensitive) sets
         ``sequential_flag``; the mode is logged either way.
     """
     # Helper clients, constructed in the original order.
     self.state_machine = StateMachine(logger)
     self.ssm = SSM(logger)
     self.s3 = S3(logger)
     self.send = Metrics(logger)
     self.param_handler = ParamsHandler(logger)

     self.logger = logger
     self.manifest_file_path = manifest_file_path
     # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
     self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
     self.wait_time = wait_time
     self.sm_arn_stackset = sm_arn_stackset
     self.manifest = None
     self.list_sm_exec_arns = []
     self.staging_bucket = staging_bucket
     self.root_id = None
     self.uuid = uuid4()
     self.state_machine_event = {}

     # .lower() is evaluated before the log line, as in the original.
     is_sequential = execution_mode.lower() == 'sequential'
     self.logger.info("Running {} mode".format(execution_mode))
     self.sequential_flag = is_sequential
예제 #11
0
 def __init__(self, logger, wait_time, manifest_file_path,
              sm_arn_launch_avm, batch_size):
     """Create Service Catalog helpers and empty lookup dictionaries.

     :param logger: logger handed to every helper client and kept on ``self``.
     :param wait_time: stored unchanged.
     :param manifest_file_path: manifest path; ``manifest_folder`` drops its last
         ``len(MANIFEST_FILE_NAME)`` characters.
     :param sm_arn_launch_avm: launch-AVM state machine ARN, stored unchanged.
     :param batch_size: stored unchanged.
     """
     # Helper clients, constructed in the original order.
     self.state_machine = StateMachine(logger)
     self.ssm = SSM(logger)
     self.sc = SC(logger)
     self.param_handler = CFNParamsHandler(logger)

     self.logger = logger
     self.manifest_file_path = manifest_file_path
     # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
     self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
     self.wait_time = wait_time
     self.sm_arn_launch_avm = sm_arn_launch_avm
     self.manifest = None
     self.list_sm_exec_arns = []
     self.batch_size = batch_size

     # AVM product details: all start as None.
     self.avm_product_name = None
     self.avm_product_id = None
     self.avm_artifact_id = None
     self.avm_params = None
     self.root_id = None

     # Lookup dictionaries, all initially empty.
     self.sc_portfolios = {}
     self.sc_products = {}
     self.provisioned_products = {}  # [productid] = []
     self.provisioned_products_by_account = {}  # [account] = [] list of ppids
예제 #12
0
if len(sys.argv) == 1:
    # No CLI arguments: print usage and exit with status 1. sys.exit (unlike
    # os._exit) runs atexit handlers and flushes stdio buffers on the way out.
    Config.print_help(sys.stderr)
    sys.exit(1)

config = Config()
balance = config.load_json("balance.json")

leftExc, rightExc = args.exchanges

# One trading client per exchange, each with its own balance entry.
leftClient = clients[leftExc](args.real, balance[leftExc], args.pair)
rightClient = clients[rightExc](args.real, balance[rightExc], args.pair)

# The strategy module supplies the transition table; start from the step given
# on the command line, or from the table's first key when none was given.
strategy = importlib.import_module('strategies.' + args.strategy)
transitions = strategy.transitions
initial_step = args.current_step or next(iter(transitions))
stateMachine = StateMachine(transitions, initial_step)
# Every filled order on either exchange advances the state machine.
leftClient.subscribe_to_order_filled(stateMachine.next)
rightClient.subscribe_to_order_filled(stateMachine.next)


def on_stream_value(v):
    left, right = v
    left_price = left['price']
    right_price = right['price']

    if stateMachine.steps == args.steps:
        os._exit(1)

    if args.log:
        print(
예제 #13
0
class StateMachineTriggerLambda(object):
    def __init__(self, logger, sm_arns_map, staging_bucket, manifest_file_path,
                 pipeline_stage, token, execution_mode, primary_account_id):
        """Create helper clients and store the pipeline-trigger settings.

        :param logger: logger handed to every helper client and kept on ``self``.
        :param sm_arns_map: state machine ARN mapping, stored unchanged.
        :param staging_bucket: S3 bucket used by ``_stage_template`` uploads.
        :param manifest_file_path: manifest path; ``manifest_folder`` drops its
            last ``len(MANIFEST_FILE_NAME)`` characters.
        :param pipeline_stage: stored unchanged.
        :param token: SSM parameter name/prefix used for execution ARNs.
        :param execution_mode: ``'sequential'`` (case-insensitive) sets
            ``isSequential``; any other string clears it.
        :param primary_account_id: stored unchanged.
        """
        # Helper clients, constructed in the original order.
        self.state_machine = StateMachine(logger)
        self.ssm = SSM(logger)
        self.s3 = S3(logger)
        self.send = Metrics(logger)
        self.param_handler = ParamsHandler(logger)

        self.logger = logger
        self.sm_arns_map = sm_arns_map
        self.manifest = None
        self.staging_bucket = staging_bucket
        self.manifest_file_path = manifest_file_path
        self.token = token
        self.pipeline_stage = pipeline_stage
        # NOTE(review): assumes the path really ends with MANIFEST_FILE_NAME.
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]

        self.isSequential = execution_mode.lower() == 'sequential'
        # _run_or_queue_state_machine treats index 100 as the first execution.
        self.index = 100
        self.primary_account_id = primary_account_id

    def _save_sm_exec_arn(self, list_sm_exec_arns):
        if list_sm_exec_arns is not None and type(list_sm_exec_arns) is list:
            self.logger.debug(
                "Saving the token:{} with list of sm_exec_arns:{}".format(
                    self.token, list_sm_exec_arns))
            if len(list_sm_exec_arns) > 0:
                sm_exec_arns = ",".join(
                    list_sm_exec_arns
                )  # Create comma seperated string from list e.g. 'a','b','c'
                self.ssm.put_parameter(
                    self.token,
                    sm_exec_arns)  # Store the list of SM execution ARNs in SSM
            else:
                self.ssm.put_parameter(self.token, 'PASS')
        else:
            raise Exception(
                "Expecting a list of state machine execution ARNs to store in SSM for token:{}, but found nothing to store."
                .format(self.token))

    def _stage_template(self, relative_template_path):
        if relative_template_path.lower().startswith('s3'):
            # Convert the remote template URL s3://bucket-name/object
            # to Virtual-hosted style URL https://bucket-name.s3.amazonaws.com/object
            t = relative_template_path.split("/", 3)
            s3_url = "https://{}.s3.amazonaws.com/{}".format(t[2], t[3])
        else:
            local_file = os.path.join(self.manifest_folder,
                                      relative_template_path)
            remote_file = "{}/{}_{}".format(
                TEMPLATE_KEY_PREFIX, self.token,
                relative_template_path[relative_template_path.rfind('/') + 1:])
            logger.info(
                "Uploading the template file: {} to S3 bucket: {} and key: {}".
                format(local_file, self.staging_bucket, remote_file))
            self.s3.upload_file(self.staging_bucket, local_file, remote_file)
            s3_url = "{}{}{}{}".format('https://s3.amazonaws.com/',
                                       self.staging_bucket, '/', remote_file)
        return s3_url

    def _download_remote_file(self, remote_s3_path):
        """Download ``s3://bucket/key`` into a new temporary file and return its path.

        The temporary file is not deleted here.
        """
        # mkstemp returns an already-open OS-level descriptor together with the
        # path. Only the path is needed, so close the descriptor; previously it
        # was dropped and leaked on every call.
        fd, _file = tempfile.mkstemp()
        os.close(fd)
        t = remote_s3_path.split("/", 3)  # s3://bucket-name/key
        remote_bucket = t[2]  # Bucket name
        remote_key = t[3]  # Key
        logger.info("Downloading {}/{} from S3 to {}".format(
            remote_bucket, remote_key, _file))
        self.s3.download_file(remote_bucket, remote_key, _file)
        return _file

    def _load_policy(self, relative_policy_path):
        """Read a policy JSON file, validate it and return its text.

        ``s3://bucket/key`` paths are downloaded to a temporary file, which is
        removed after reading. Other paths are resolved against the manifest
        folder.

        :returns: the file content with every LF line ending turned into CRLF.
        :raises ValueError: if the content is not valid JSON.
        """
        # Full scheme match; a bare 's3' prefix also matched local paths.
        is_remote = relative_policy_path.lower().startswith('s3://')
        if is_remote:
            policy_file = self._download_remote_file(relative_policy_path)
        else:
            policy_file = os.path.join(self.manifest_folder,
                                       relative_policy_path)

        logger.info("Parsing the policy file: {}".format(policy_file))

        try:
            with open(policy_file, 'r') as content_file:
                policy_file_content = content_file.read()
        finally:
            if is_remote:
                # Best-effort cleanup of the downloaded temporary copy.
                try:
                    os.remove(policy_file)
                except OSError:
                    pass

        # Check if valid json
        json.loads(policy_file_content)
        # The former .replace('"', '\"') was removed: '\"' is just '"' in
        # Python, so it never escaped anything and the output is unchanged.
        return policy_file_content.replace('\n', '\r\n')

    def _load_params(self, relative_parameter_path, account=None, region=None):
        """Load a parameter JSON file and resolve it with ``self.param_handler``.

        ``s3://bucket/key`` paths are downloaded to a temporary file, which is
        removed after reading. Other paths are resolved against the manifest
        folder.

        With ``account`` (core resource stack set), ``update_params`` gets the
        account, the region and ``False`` as its last argument, so SSM parameter
        values are not replaced yet. Without ``account`` (baseline stack set)
        it receives only the parsed parameters.

        :returns: the result of ``self.param_handler.update_params``.
        """
        # Full scheme match; a bare 's3' prefix also matched local paths.
        is_remote = relative_parameter_path.lower().startswith('s3://')
        if is_remote:
            parameter_file = self._download_remote_file(
                relative_parameter_path)
        else:
            parameter_file = os.path.join(self.manifest_folder,
                                          relative_parameter_path)

        logger.info("Parsing the parameter file: {}".format(parameter_file))

        try:
            with open(parameter_file, 'r') as content_file:
                parameter_file_content = content_file.read()
        finally:
            if is_remote:
                # Best-effort cleanup of the downloaded temporary copy.
                try:
                    os.remove(parameter_file)
                except OSError:
                    pass

        params = json.loads(parameter_file_content)
        if account is not None:
            #Deploying Core resource Stack Set
            # The last parameter is set to False, because we do not want to replace the SSM parameter values yet.
            sm_params = self.param_handler.update_params(
                params, account, region, False)
        else:
            # Deploying Baseline resource Stack Set
            sm_params = self.param_handler.update_params(params)

        logger.info("Input Parameters for State Machine: {}".format(sm_params))
        return sm_params

    def _load_template_rules(self, relative_rules_path):
        """Parse a template-constraint rules JSON file from the manifest folder.

        :param relative_rules_path: path relative to ``self.manifest_folder``.
        :returns: the decoded JSON object.
        """
        path = os.path.join(self.manifest_folder, relative_rules_path)
        logger.info("Parsing the template rules file: {}".format(path))
        with open(path, 'r') as handle:
            rules = json.load(handle)
        logger.info(
            "Template Constraint Rules for State Machine: {}".format(rules))
        return rules

    def _populate_ssm_params(self, sm_input):
        """Resolve SSM parameter references in ``sm_input`` (in place) and return it.

        Needed when one core resource exports stack outputs to SSM and a later
        core resource reads them: the values can only be read once the first
        resource has finished, so they are resolved just before triggering.
        """
        logger.debug("Populating SSM parameter values for SM input: {}".format(
            sm_input))
        resource_props = sm_input.get('ResourceProperties')
        raw_params = resource_props.get('Parameters', {})
        # {name: value} -> CloudFormation-style parameter list (via
        # transform_params), then SSM parameter names -> their values.
        resolved = self.param_handler.update_params(transform_params(raw_params))
        resource_props.update({'Parameters': resolved})
        logger.debug(
            "Done populating SSM parameter values for SM input: {}".format(
                sm_input))
        return sm_input

    def _create_ssm_input_map(self, ssm_parameters):
        ssm_input_map = {}

        for ssm_parameter in ssm_parameters:
            key = ssm_parameter.name
            value = ssm_parameter.value
            ssm_value = self.param_handler.update_params(
                transform_params({key: value}))
            ssm_input_map.update(ssm_value)

        return ssm_input_map

    def _create_state_machine_input_map(self,
                                        input_params,
                                        request_type='Create'):
        request = {}
        request.update({'RequestType': request_type})
        request.update({'ResourceProperties': input_params})

        return request

    def _create_account_state_machine_input_map(self,
                                                ou_name,
                                                account_name='',
                                                account_email='',
                                                ssm_map=None):
        input_params = {}
        input_params.update({'OUName': ou_name})
        input_params.update({'AccountName': account_name})
        input_params.update({'AccountEmail': account_email})
        if ssm_map is not None:
            input_params.update({'SSMParameters': ssm_map})
        return self._create_state_machine_input_map(input_params)

    def _create_stack_set_state_machine_input_map(
            self,
            stack_set_name,
            template_url,
            parameters,
            account_list=None,
            regions_list=None,
            ssm_map=None,
            capabilities='CAPABILITY_NAMED_IAM'):
        """Build a 'Create' request for the stack-set state machine.

        :param stack_set_name: stack set name (sanitized).
        :param template_url: template URL, passed through.
        :param parameters: stack set parameters, passed through.
        :param account_list: target accounts; ``None`` or empty means no
            targets, in which case ``AccountList`` and ``RegionList`` are both
            set to ``''``.
        :param regions_list: target regions; ``None`` or empty falls back to
            ``[self.manifest.region]`` when there are target accounts.
        :param ssm_map: added as ``SSMParameters`` when not ``None``.
        :param capabilities: CloudFormation capabilities string.
        """
        # None defaults replace shared mutable [] defaults; None and empty
        # sequences behave the same.
        account_list = account_list or []
        regions_list = regions_list or []

        input_params = {
            'StackSetName': sanitize(stack_set_name),
            'TemplateURL': template_url,
            'Parameters': parameters,
            'Capabilities': capabilities,
        }

        if account_list:
            input_params['AccountList'] = account_list
            if regions_list:
                input_params['RegionList'] = regions_list
            else:
                input_params['RegionList'] = [self.manifest.region]
        else:
            # No target accounts: blank strings (not empty lists) for both keys.
            input_params['AccountList'] = ''
            input_params['RegionList'] = ''

        if ssm_map is not None:
            input_params['SSMParameters'] = ssm_map

        return self._create_state_machine_input_map(input_params)

    def _create_service_control_policy_state_machine_input_map(
            self, policy_name, policy_content, policy_desc=''):
        """Build a 'Create' request for the service-control-policy state machine.

        The policy name is sanitized. ``AccountId``, ``PolicyList`` and
        ``Operation`` are left blank (``''``, ``[]``, ``''``).
        """
        policy_doc = {
            'Name': sanitize(policy_name),
            'Description': policy_desc,
            'Content': policy_content,
        }
        properties = {
            'PolicyDocument': policy_doc,
            'AccountId': '',
            'PolicyList': [],
            'Operation': '',
        }
        return self._create_state_machine_input_map(properties)

    def _create_service_catalog_state_machine_input_map(
            self, portfolio, product):
        """Build the 'Create' request for the Service Catalog state machine.

        Portfolio and product names, owners and the portfolio description are
        sanitized. The principal and launch-constraint roles are resolved
        through ``self.param_handler.update_params``. The product template URL
        is derived from the product's skeleton file:

        * ``baseline``: ``<skeleton_file>.template`` is staged as-is;
        * ``optional``: ``template_file`` is staged, the Jinja2 skeleton is
          rendered with that URL into ``<skeleton_file>.generated.template``,
          and the generated file is staged.

        Any other ``product_type`` leaves the template URL as ``None``.
        Template constraint rules are attached when ``product.rules_file`` is
        set; errors while loading them are logged and ignored.

        :raises Exception: if the skeleton file does not exist, or an optional
            product has no ``template_file``.
        """
        input_params = {}

        sc_portfolio = {}
        sc_portfolio.update({'PortfolioName': sanitize(portfolio.name, True)})
        sc_portfolio.update(
            {'PortfolioDescription': sanitize(portfolio.description, True)})
        sc_portfolio.update(
            {'PortfolioProvider': sanitize(portfolio.owner, True)})
        ssm_value = self.param_handler.update_params(
            transform_params({'principal_role': portfolio.principal_role}))
        sc_portfolio.update({'PrincipalArn': ssm_value.get('principal_role')})

        sc_product = {}
        sc_product.update({'ProductName': sanitize(product.name, True)})
        sc_product.update({'ProductDescription': product.description})
        sc_product.update({'ProductOwner': sanitize(portfolio.owner, True)})
        sc_product.update(
            {'HideOldVersions': 'Yes' if product.hide_old_versions is True else 'No'})
        ssm_value = self.param_handler.update_params(
            transform_params(
                {'launch_constraint_role': product.launch_constraint_role}))
        sc_product.update({'RoleArn': ssm_value.get('launch_constraint_role')})

        # (An EC2 describe_regions call used to be made here; its result fed
        # only commented-out render code, so the unused API call was dropped.)
        skeleton_path = os.path.join(self.manifest_folder, product.skeleton_file)
        if not os.path.isfile(skeleton_path):
            raise Exception(
                "Missing skeleton_file for portfolio:{} and product:{} in Manifest file"
                .format(portfolio.name, product.name))

        lambda_arn_param = get_env_var('lambda_arn_param_name')
        lambda_arn = self.ssm.get_parameter(lambda_arn_param)
        portfolio_index = self.manifest.portfolios.index(portfolio)
        product_index = self.manifest.portfolios[
            portfolio_index].products.index(product)
        product_name = self.manifest.portfolios[portfolio_index].products[
            product_index].name
        logger.info(
            "Generating the product template for {} from {}".format(
                product_name, skeleton_path))
        j2loader = jinja2.FileSystemLoader(self.manifest_folder)
        j2env = jinja2.Environment(loader=j2loader)
        j2template = j2env.get_template(product.skeleton_file)
        template_url = None
        product_type = product.product_type.lower()
        if product_type == 'baseline':
            # Baseline: stage the pre-built '<skeleton>.template' unchanged;
            # the Jinja2 skeleton is not rendered for this type.
            template_url = self._stage_template(product.skeleton_file +
                                                ".template")
        elif product_type == 'optional':
            # Truthiness also turns a None template_file into the explicit
            # error below instead of a TypeError from len(None).
            if not product.template_file:
                raise Exception(
                    "Missing template_file location for portfolio:{} and product:{} in Manifest file"
                    .format(portfolio.name, product.name))
            template_url = self._stage_template(product.template_file)
            j2result = j2template.render(
                manifest=self.manifest,
                portfolio_index=portfolio_index,
                product_index=product_index,
                lambda_arn=lambda_arn,
                uuid=uuid.uuid4(),
                template_url=template_url)
            generated_avm_template = os.path.join(
                self.manifest_folder,
                product.skeleton_file + ".generated.template")
            logger.info(
                "Writing the generated product template to {}".format(
                    generated_avm_template))
            with open(generated_avm_template, "w") as fh:
                fh.write(j2result)
            template_url = self._stage_template(generated_avm_template)
        # NOTE(review): any other product_type leaves template_url as None.

        artifact_params = {}
        artifact_params.update({'Info': {'LoadTemplateFromURL': template_url}})
        artifact_params.update({'Type': 'CLOUD_FORMATION_TEMPLATE'})
        artifact_params.update({'Description': product.description})
        sc_product.update({'ProvisioningArtifactParameters': artifact_params})

        # Rules are optional: failures are logged and otherwise ignored.
        try:
            if product.rules_file:
                rules = self._load_template_rules(product.rules_file)
                sc_product.update({'Rules': rules})
        except Exception as e:
            logger.error(e)

        input_params.update({'SCPortfolio': sc_portfolio})
        input_params.update({'SCProduct': sc_product})

        return self._create_state_machine_input_map(input_params)

    def _create_launch_avm_state_machine_input_map(self, portfolio, product,
                                                   accounts):
        """Build a 'Create' request for the launch-AVM state machine.

        :param portfolio: portfolio name (sanitized).
        :param product: product name (sanitized).
        :param accounts: provisioning parameter list, passed through unchanged.
        """
        properties = {
            'PortfolioName': sanitize(portfolio, True),
            'ProductName': sanitize(product, True),
            'ProvisioningParametersList': accounts,
        }
        return self._create_state_machine_input_map(properties)

    def _run_or_queue_state_machine(self, sm_input, sm_arn, list_sm_exec_arns,
                                    sm_name):
        """Start a state machine execution now, or queue its input in SSM.

        Sequential mode: only the call made while ``self.index == 100`` starts
        an execution; every later input is stored as JSON in SSM under
        ``/<token>/<index>``. Parallel mode: every call starts an execution and
        then sleeps ``int(wait_time)`` seconds. ARNs of started executions are
        appended to ``list_sm_exec_arns``. ``self.index`` grows by one per call.
        """
        logger.info("State machine Input: {}".format(sm_input))
        exec_name = "%s-%s-%s" % (sm_input.get('RequestType'),
                                  sm_name.replace(" ", ""),
                                  time.strftime("%Y-%m-%dT%H-%M-%S"))

        queue_only = self.isSequential and self.index != 100
        if queue_only:
            # Sequential, not the first: park the input under /<token>/<index>.
            param_name = "/%s/%s" % (self.token, self.index)
            self.ssm.put_parameter(param_name, json.dumps(sm_input))
        else:
            sm_input = self._populate_ssm_params(sm_input)
            sm_exec_arn = self.state_machine.trigger_state_machine(
                sm_arn, sm_input, exec_name)
            if not self.isSequential:
                # NOTE(review): `wait_time` is neither a local nor an attribute
                # here; it resolves as a module-level global.
                time.sleep(int(wait_time))
            list_sm_exec_arns.append(sm_exec_arn)
        self.index += 1

    def _deploy_resource(self,
                         resource,
                         sm_arn,
                         list_sm_exec_arns,
                         account_id=None):
        """Deploy one manifest resource as a stack set through the state machine.

        With ``account_id`` it is a core resource targeted at that account
        (stack set ``AWS-Landing-Zone-<name>``). Without it, it is a baseline
        resource (``AWS-Landing-Zone-Baseline-<name>``). In the baseline case,
        if the stack set already exists and both its template body and its
        parameter values match the local copies, no execution is started.
        Otherwise the request goes through ``_run_or_queue_state_machine``.
        """
        template_full_path = self._stage_template(resource.template_file)
        params = {}
        deploy_resource_flag = True
        if resource.parameter_file:
            # First listed region, else the manifest's default region.
            region = resource.regions[0] if len(resource.regions) > 0 else self.manifest.region
            params = self._load_params(resource.parameter_file, account_id,
                                       region)

        ssm_map = self._create_ssm_input_map(resource.ssm_parameters)

        if account_id is not None:
            #Deploying Core resource Stack Set
            stack_name = "AWS-Landing-Zone-{}".format(resource.name)
            sm_input = self._create_stack_set_state_machine_input_map(
                stack_name, template_full_path, params, [str(account_id)],
                resource.regions, ssm_map)
        else:
            #Deploying Baseline resource Stack Set
            stack_name = "AWS-Landing-Zone-Baseline-{}".format(resource.name)
            sm_input = self._create_stack_set_state_machine_input_map(
                stack_name, template_full_path, params, [], [], ssm_map)

            stack_set = StackSet(self.logger)
            response = stack_set.describe_stack_set(stack_name)
            if response is not None:
                self.logger.info("Found existing stack set.")
                self.logger.info(
                    "Comparing the template of the StackSet: {} with local copy of template"
                    .format(stack_name))
                relative_template_path = resource.template_file
                temp_files = []
                try:
                    # Full scheme match; a bare 's3' prefix also matched local paths.
                    if relative_template_path.lower().startswith('s3://'):
                        local_template_file = self._download_remote_file(
                            relative_template_path)
                        temp_files.append(local_template_file)
                    else:
                        local_template_file = os.path.join(
                            self.manifest_folder, relative_template_path)

                    # Write the deployed template body through mkstemp's own
                    # descriptor so it gets closed (it used to leak).
                    fd, cfn_template_file = tempfile.mkstemp()
                    temp_files.append(cfn_template_file)
                    with os.fdopen(fd, "w") as f:
                        f.write(response.get('StackSet').get('TemplateBody'))

                    # shallow=False: always compare contents, not os.stat() signatures.
                    template_compare = filecmp.cmp(local_template_file,
                                                   cfn_template_file,
                                                   shallow=False)
                finally:
                    # Best-effort removal of the temporary template copies.
                    for temp_file in temp_files:
                        try:
                            os.remove(temp_file)
                        except OSError:
                            pass

                self.logger.info(
                    "Comparing the parameters of the StackSet: {} with local copy of JSON parameters file"
                    .format(stack_name))
                params_compare = True
                if template_compare:
                    cfn_params = reverse_transform_params(
                        response.get('StackSet').get('Parameters'))
                    # Every local parameter must match the deployed value;
                    # parameters missing from the stack set compare as ''.
                    params_compare = all(
                        cfn_params.get(key, '') == value
                        for key, value in params.items())

                self.logger.info(
                    "template_compare={}".format(template_compare))
                self.logger.info("params_compare={}".format(params_compare))
                if template_compare and params_compare:
                    deploy_resource_flag = False
                    self.logger.info(
                        "Found no changes in template & parameters, so skipping Update StackSet for {}"
                        .format(stack_name))

        if deploy_resource_flag:
            self._run_or_queue_state_machine(sm_input, sm_arn,
                                             list_sm_exec_arns, stack_name)

    def start_core_account_sm(self, sm_arn_account):
        """Run or queue the account state machine for every manifest OU.

        Steps:
        1. Record the manifest's ``lock_down_stack_sets_role`` flag in SSM
           as ``yes``/``no``.
        2. Send the ``PipelineRunCount`` metric.
        3. Build one state machine input per core account, or a single
           OU-only input for an OU with no core accounts, and pass each one to
           ``_run_or_queue_state_machine``.
        4. Save the collected execution ARNs.

        Args:
            sm_arn_account: ARN of the account state machine.

        Raises:
            Exception: if a non-primary core account has no email address.
        """
        try:
            lock_down = self.manifest.lock_down_stack_sets_role
            logger.info(
                "Setting the lock_down_stack_sets_role={}".format(lock_down))
            # Only a literal True is stored as 'yes'; anything else is 'no'.
            self.ssm.put_parameter('lock_down_stack_sets_role_flag',
                                   'yes' if lock_down is True else 'no')

            # Send metric - pipeline run count
            self.send.metrics({"PipelineRunCount": "1"})

            logger.info("Processing Core Accounts from {} file".format(
                self.manifest_file_path))
            exec_arns = []
            for ou in self.manifest.organizational_units:
                logger.info(
                    "Generating the state machine input json for OU: {}".
                    format(ou.name))

                if len(ou.core_accounts) == 0:
                    # An OU without accounts still gets an OU-only input.
                    self._run_or_queue_state_machine(
                        self._create_account_state_machine_input_map(ou.name),
                        sm_arn_account, exec_arns, ou.name)

                for account in ou.core_accounts:
                    if account.name.lower() == 'primary':
                        # The primary account is submitted with a blank email.
                        email = ''
                    elif account.email:
                        email = account.email
                    else:
                        raise Exception(
                            "Failed to retrieve the email address for the Account: {}"
                            .format(account.name))

                    account_input = self._create_account_state_machine_input_map(
                        ou.name, account.name, email,
                        self._create_ssm_input_map(account.ssm_parameters))
                    self._run_or_queue_state_machine(account_input,
                                                     sm_arn_account,
                                                     exec_arns, account.name)
            self._save_sm_exec_arn(exec_arns)
            return
        except Exception as err:
            self.logger.exception({
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(err)
            })
            raise

    def start_core_resource_sm(self, sm_arn_stack_set):
        """Deploy every core-account resource of the manifest as a stack set.

        For each core account:
        - The account Id is read from SSM, using the parameter whose manifest
          value is the ``$[AccountId]`` placeholder.
        - Each of the account's core resources is passed to
          ``_deploy_resource``.

        Afterwards the number of core resources is sent as the
        ``CoreAccountStackSetCount`` metric and the collected execution ARNs
        are saved.

        Args:
            sm_arn_stack_set: ARN of the stack-set state machine.

        Raises:
            Exception: if an account's Id cannot be resolved from SSM, or a
                resource uses a deploy_method other than ``stack_set``.
        """
        try:
            logger.info("Parsing Core Resources from {} file".format(
                self.manifest_file_path))
            list_sm_exec_arns = []
            count = 0
            for ou in self.manifest.organizational_units:
                for account in ou.core_accounts:
                    account_name = account.name
                    # Resolve the account Id via the SSM parameter holding the
                    # '$[AccountId]' placeholder. If several match, the last
                    # one wins.
                    account_id = ''
                    account_id_param = None
                    for ssm_parameter in account.ssm_parameters:
                        if ssm_parameter.value == '$[AccountId]':
                            account_id_param = ssm_parameter.name
                            account_id = self.ssm.get_parameter(
                                ssm_parameter.name)

                    if not account_id:
                        # Name the parameter that was actually consulted. The
                        # loop variable is unbound when the account has no SSM
                        # parameters, or stale from a previous account.
                        raise Exception(
                            "Missing required SSM parameter: {} to retrieve the account Id of Account: {} defined in Manifest"
                            .format(account_id_param or '$[AccountId]',
                                    account_name))

                    for resource in account.core_resources:
                        # Count number of stacksets
                        count += 1
                        if resource.deploy_method.lower() == 'stack_set':
                            self._deploy_resource(resource, sm_arn_stack_set,
                                                  list_sm_exec_arns,
                                                  account_id)
                        else:
                            raise Exception(
                                "Unsupported deploy_method: {} found for resource {} and Account: {} in Manifest"
                                .format(resource.deploy_method, resource.name,
                                        account_name))
            data = {"CoreAccountStackSetCount": str(count)}
            self.send.metrics(data)
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_service_control_policy_sm(self, sm_arn_scp):
        """Run or queue the SCP state machine for each manifest organization policy.

        Each policy file is loaded and turned into a state machine input. The
        execution ARNs are saved first, and then the number of policies is
        sent as the ``SCPPolicyCount`` metric.

        Args:
            sm_arn_scp: ARN of the service control policy state machine.
        """
        try:
            logger.info("Processing SCPs from {} file".format(
                self.manifest_file_path))
            exec_arns = []
            policy_count = 0
            for policy_count, policy in enumerate(
                    self.manifest.organization_policies, start=1):
                content = self._load_policy(policy.policy_file)
                self._run_or_queue_state_machine(
                    self._create_service_control_policy_state_machine_input_map(
                        policy.name, content, policy.description),
                    sm_arn_scp, exec_arns, policy.name)
            self._save_sm_exec_arn(exec_arns)
            self.send.metrics({"SCPPolicyCount": str(policy_count)})
            return
        except Exception as err:
            self.logger.exception({
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(err)
            })
            raise

    def start_service_catalog_sm(self, sm_arn_sc):
        """Run or queue the Service Catalog state machine for every portfolio product.

        Args:
            sm_arn_sc: ARN of the Service Catalog state machine.
        """
        try:
            logger.info(
                "Processing Service catalogs section from {} file".format(
                    self.manifest_file_path))
            exec_arns = []
            for portfolio in self.manifest.portfolios:
                for product in portfolio.products:
                    self._run_or_queue_state_machine(
                        self._create_service_catalog_state_machine_input_map(
                            portfolio, product),
                        sm_arn_sc, exec_arns, product.name)
            self._save_sm_exec_arn(exec_arns)
            return
        except Exception as err:
            self.logger.exception({
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(err)
            })
            raise

    def start_baseline_resources_sm(self, sm_arn_stack_set):
        """Deploy every manifest baseline resource as a stack set.

        The number of deployed resources is sent as the
        ``BaselineStackSetCount`` metric before the execution ARNs are saved.

        Args:
            sm_arn_stack_set: ARN of the stack-set state machine.

        Raises:
            Exception: if a resource uses a deploy_method other than
                ``stack_set``.
        """
        try:
            logger.info("Parsing Basline Resources from {} file".format(
                self.manifest_file_path))
            exec_arns = []
            stack_set_count = 0
            for resource in self.manifest.baseline_resources:
                if resource.deploy_method.lower() != 'stack_set':
                    raise Exception(
                        "Unsupported deploy_method: {} found for resource {} in Manifest"
                        .format(resource.deploy_method, resource.name))
                self._deploy_resource(resource, sm_arn_stack_set, exec_arns)
                stack_set_count += 1
            self.send.metrics({"BaselineStackSetCount": str(stack_set_count)})
            self._save_sm_exec_arn(exec_arns)
            return
        except Exception as err:
            self.logger.exception({
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(err)
            })
            raise

    def start_launch_avm(self, sm_arn_launch_avm):
        """Start the launch-AVM state machine for every ``baseline`` product.

        For each baseline product, its parameter file is expanded once per
        account found in each OU listed in
        ``apply_baseline_to_accounts_in_ou``. The case-insensitive
        placeholders ``AccountEmail``, ``AccountName`` and ``OrgUnitName``
        are substituted. One execution is started per product that yields at
        least one account, and the execution ARNs are saved at the end.

        Args:
            sm_arn_launch_avm: ARN of the launch-AVM state machine.

        Raises:
            Exception: if an OU named by a product is not found directly under
                the organization root.
        """
        try:
            logger.info("Starting the launch AVM trigger")
            list_sm_exec_arns = []

            org = Organizations(self.logger)
            response = org.list_roots()
            self.logger.info("List roots Response")
            self.logger.info(response)
            root_id = response['Roots'][0].get('Id')

            ou_id_map = self._build_root_ou_id_map(org, root_id)
            self.logger.info("ou_id_map={}".format(ou_id_map))

            for portfolio in self.manifest.portfolios:
                for product in portfolio.products:
                    if product.product_type.lower() != 'baseline':
                        continue
                    _params = self._load_params(product.parameter_file)
                    logger.info(
                        "Input parameters format for AVM: {}".format(_params))
                    list_of_accounts = []
                    for ou in product.apply_baseline_to_accounts_in_ou:
                        self.logger.debug(
                            "Looking up ou={} in ou_id_map".format(ou))
                        ou_id = ou_id_map.get(ou)
                        self.logger.debug(
                            "ou_id={} for ou={} in ou_id_map".format(
                                ou_id, ou))
                        if ou_id is None:
                            # Fail clearly instead of listing accounts for a
                            # None parent id.
                            raise Exception(
                                "Organizational Unit: {} used by product: {} was not found under the organization root"
                                .format(ou, product.name))

                        response = org.list_accounts_for_parent(ou_id)
                        self.logger.debug("List Accounts for Parent Response")
                        self.logger.debug(response)
                        for account in response.get('Accounts') or []:
                            params = self._resolve_avm_account_params(
                                _params, account, ou)
                            logger.info(
                                "Input parameters format for Account: {} are {}"
                                .format(account.get('Name'), params))
                            list_of_accounts.append(params)

                    if list_of_accounts:
                        sm_input = self._create_launch_avm_state_machine_input_map(
                            portfolio.name, product.name, list_of_accounts)
                        logger.info(
                            "Launch AVM state machine Input: {}".format(
                                sm_input))
                        exec_name = "%s-%s-%s" % (
                            sm_input.get('RequestType'), "Launch-AVM",
                            time.strftime("%Y-%m-%dT%H-%M-%S"))
                        sm_exec_arn = self.state_machine.trigger_state_machine(
                            sm_arn_launch_avm, sm_input, exec_name)
                        list_sm_exec_arns.append(sm_exec_arn)
                        # Pause only after an execution was actually started.
                        # This also keeps the next execution name, which has
                        # one-second resolution, from colliding with this one.
                        time.sleep(int(wait_time))
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def _build_root_ou_id_map(self, org, root_id):
        """Return {OU name: OU id} for OUs directly under root_id, following NextToken."""
        ou_id_map = {}
        kwargs = {'ParentId': root_id}
        while True:
            response = org.list_organizational_units_for_parent(**kwargs)
            for ou in response['OrganizationalUnits']:
                ou_id_map.update({ou.get('Name'): ou.get('Id')})
            next_token = response.get('NextToken', None)
            if next_token is None:
                return ou_id_map
            kwargs = {'ParentId': root_id, 'NextToken': next_token}

    @staticmethod
    def _resolve_avm_account_params(template_params, account, ou_name):
        """Return a copy of template_params with AVM account placeholders substituted."""
        params = template_params.copy()
        for key, value in template_params.items():
            # Only string values can be placeholders. Skip others instead of
            # failing on .lower().
            if not isinstance(value, str):
                continue
            placeholder = value.lower()
            if placeholder == 'accountemail':
                params[key] = account.get('Email')
            elif placeholder == 'accountname':
                params[key] = account.get('Name')
            elif placeholder == 'orgunitname':
                params[key] = ou_name
        return params

    def trigger_state_machines(self):
        """Load the manifest and start the state machines for self.pipeline_stage.

        Each known stage maps to a ``start_*`` method and to the key of its
        state machine ARN in ``self.sm_arns_map``. An unrecognised stage does
        nothing after the manifest has been loaded.
        """
        try:
            self.manifest = Manifest(self.manifest_file_path)

            stage_dispatch = {
                'core_accounts': (self.start_core_account_sm, 'account'),
                'core_resources': (self.start_core_resource_sm, 'stack_set'),
                'service_control_policy': (
                    self.start_service_control_policy_sm,
                    'service_control_policy'),
                'service_catalog': (self.start_service_catalog_sm,
                                    'service_catalog'),
                'baseline_resources': (self.start_baseline_resources_sm,
                                       'stack_set'),
                'launch_avm': (self.start_launch_avm, 'launch_avm'),
            }
            entry = stage_dispatch.get(self.pipeline_stage)
            if entry is not None:
                start_stage, arn_key = entry
                start_stage(self.sm_arns_map.get(arn_key))

        except Exception as err:
            self.logger.exception({
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(err)
            })
            raise

    def get_state_machines_execution_status(self):
        """Report the aggregate status of the executions recorded under self.token.

        The SSM parameter named ``self.token`` holds either the literal
        ``'PASS'`` or a comma-separated list of execution ARNs. ('PASS' is
        presumably written by ``_save_sm_exec_arn`` when nothing was
        triggered; confirm.)

        In sequential mode (``self.isSequential``), once every recorded
        execution has succeeded, the next queued input is taken from the SSM
        parameters under the ``self.token`` path (lowest ``Name`` first). It
        is triggered, its ARN is saved, its queue entry is deleted, and
        'RUNNING' is returned.

        Returns:
            tuple: ``(status, error_message)`` where status is 'SUCCEEDED',
            'RUNNING' or 'FAILED'. error_message is '' unless 'FAILED'.
        """
        try:
            sm_exec_arns = self.ssm.get_parameter(self.token)

            if sm_exec_arns == 'PASS':
                self.ssm.delete_parameter(self.token)
                return 'SUCCEEDED', ''
            else:
                list_sm_exec_arns = sm_exec_arns.split(
                    ","
                )  # Create a list from comma separated string e.g. ['a','b','c']

                # The first non-succeeded execution decides the result:
                # RUNNING returns immediately; any other non-SUCCEEDED status
                # cleans up the token parameter and the queued inputs under
                # its path, then reports FAILED.
                for sm_exec_arn in list_sm_exec_arns:
                    status = self.state_machine.check_state_machine_status(
                        sm_exec_arn)
                    if status == 'RUNNING':
                        return 'RUNNING', ''
                    elif status == 'SUCCEEDED':
                        continue
                    else:
                        self.ssm.delete_parameter(self.token)
                        self.ssm.delete_parameters_by_path(self.token)
                        err_msg = "State Machine Execution Failed, please check the Step function console for State Machine Execution ARN: {}".format(
                            sm_exec_arn)
                        return 'FAILED', err_msg

                # All recorded executions succeeded. In sequential mode, start
                # the next queued input (if any) and keep reporting RUNNING.
                if self.isSequential:
                    _params_list = self.ssm.get_parameters_by_path(self.token)
                    if _params_list:
                        # Sorting by Name picks the queued inputs in name order.
                        params_list = sorted(_params_list,
                                             key=lambda i: i['Name'])
                        sm_input = json.loads(params_list[0].get('Value'))
                        # NOTE(review): sm_arn/sm_name are only bound for the
                        # five stages below; any other pipeline_stage reaching
                        # this point would raise UnboundLocalError.
                        if self.pipeline_stage == 'core_accounts':
                            sm_arn = self.sm_arns_map.get('account')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'OUName') + "-" + sm_input.get(
                                    'ResourceProperties').get('AccountName')

                            account_name = sm_input.get(
                                'ResourceProperties').get('AccountName')
                            # The primary account was queued with a blank
                            # email; fill it in from Organizations here.
                            if account_name.lower() == 'primary':
                                org = Organizations(self.logger)
                                response = org.describe_account(
                                    self.primary_account_id)
                                account_email = response.get('Account').get(
                                    'Email', '')
                                sm_input.get('ResourceProperties').update(
                                    {'AccountEmail': account_email})

                        elif self.pipeline_stage == 'core_resources':
                            sm_arn = self.sm_arns_map.get('stack_set')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'StackSetName')
                            # Resolve SSM parameter values only now, so that
                            # outputs written by earlier stack sets are visible.
                            sm_input = self._populate_ssm_params(sm_input)
                        elif self.pipeline_stage == 'service_control_policy':
                            sm_arn = self.sm_arns_map.get(
                                'service_control_policy')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'PolicyDocument').get('Name')
                        elif self.pipeline_stage == 'service_catalog':
                            sm_arn = self.sm_arns_map.get('service_catalog')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'SCProduct').get('ProductName')
                        elif self.pipeline_stage == 'baseline_resources':
                            sm_arn = self.sm_arns_map.get('stack_set')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'StackSetName')
                            sm_input = self._populate_ssm_params(sm_input)

                        # Spaces are stripped from the name before it is used
                        # in the execution name.
                        exec_name = "%s-%s-%s" % (sm_input.get(
                            'RequestType'), sm_name.replace(
                                " ", ""), time.strftime("%Y-%m-%dT%H-%M-%S"))
                        sm_exec_arn = self.state_machine.trigger_state_machine(
                            sm_arn, sm_input, exec_name)
                        self._save_sm_exec_arn([sm_exec_arn])
                        # Dequeue the input that was just started.
                        self.ssm.delete_parameter(params_list[0].get('Name'))
                        return 'RUNNING', ''

                self.ssm.delete_parameter(self.token)
                return 'SUCCEEDED', ''

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise
class DeployStackSetStateMachine(object):
    def __init__(self, logger, wait_time, manifest_file_path, sm_arn_stackset, staging_bucket, execution_mode):
        """Set up service clients and run state for deploying manifest stack sets.

        Args:
            logger: Logger shared with every service client.
            wait_time: Stored as ``self.wait_time``.
            manifest_file_path: Path of the manifest file. Its directory
                (the path minus the trailing MANIFEST_FILE_NAME) becomes
                ``self.manifest_folder``.
            sm_arn_stackset: ARN of the stack-set state machine.
            staging_bucket: S3 bucket that local templates are uploaded to.
            execution_mode: ``'sequential'`` (case-insensitive) sets
                ``self.sequential_flag``; any other value clears it.
        """
        # Service clients
        self.state_machine = StateMachine(logger)
        self.ssm = SSM(logger)
        self.s3 = S3(logger)
        self.send = Metrics(logger)
        self.param_handler = ParamsHandler(logger)
        self.logger = logger

        # Manifest location; the folder is the path with the file name cut off.
        self.manifest_file_path = manifest_file_path
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]

        # Run configuration and mutable run state
        self.wait_time = wait_time
        self.sm_arn_stackset = sm_arn_stackset
        self.manifest = None
        self.list_sm_exec_arns = []
        self.staging_bucket = staging_bucket
        self.root_id = None
        self.uuid = uuid4()
        self.state_machine_event = {}

        is_sequential = execution_mode.lower() == 'sequential'
        self.logger.info("Running {} mode".format(execution_mode))
        self.sequential_flag = is_sequential

    def _stage_template(self, relative_template_path):
        """Return an HTTPS URL for the template.

        A path starting with ``s3`` (case-insensitive) is converted with
        ``convert_s3_url_to_http_url``. Any other path is taken relative to
        the manifest folder, uploaded to the staging bucket under
        ``TEMPLATE_KEY_PREFIX/<uuid>_<path>``, and its
        ``https://s3.amazonaws.com/`` URL is returned.
        """
        try:
            if not relative_template_path.lower().startswith('s3'):
                local_file = os.path.join(self.manifest_folder, relative_template_path)
                remote_file = "{}/{}_{}".format(TEMPLATE_KEY_PREFIX, self.uuid, relative_template_path)
                logger.info("Uploading the template file: {} to S3 bucket: {} and key: {}".format(
                    local_file, self.staging_bucket, remote_file))
                self.s3.upload_file(self.staging_bucket, local_file, remote_file)
                return "{}{}{}{}".format('https://s3.amazonaws.com/', self.staging_bucket, '/', remote_file)

            # s3://bucket-name/object -> https://s3.amazonaws.com/bucket-name/object
            return convert_s3_url_to_http_url(relative_template_path)
        except Exception as err:
            self.logger.exception({'FILE': __file__.split('/')[-1],
                                   'METHOD': inspect.stack()[0][3],
                                   'EXCEPTION': str(err)})
            raise

    def _load_params(self, relative_parameter_path, account=None, region=None):
        """Load a JSON parameter file and pass it through the params handler.

        A path starting with ``s3`` is downloaded first. Any other path is
        relative to the manifest folder.

        For core resources (``account`` given), ``update_params`` is called
        with ``account``, ``region`` and ``False``, so SSM parameter values
        are not substituted yet. For baseline resources it is called with the
        parameters only.
        """
        try:
            if relative_parameter_path.lower().startswith('s3'):
                parameter_file = download_remote_file(self.logger, relative_parameter_path)
            else:
                parameter_file = os.path.join(self.manifest_folder, relative_parameter_path)

            logger.info("Parsing the parameter file: {}".format(parameter_file))

            with open(parameter_file, 'r') as handle:
                raw_params = json.load(handle)

            if account is None:
                # Baseline resource Stack Set
                sm_params = self.param_handler.update_params(raw_params)
            else:
                # Core resource Stack Set: defer SSM value substitution
                sm_params = self.param_handler.update_params(raw_params, account, region, False)

            logger.info("Input Parameters for State Machine: {}".format(sm_params))
            return sm_params
        except Exception as err:
            self.logger.exception({'FILE': __file__.split('/')[-1],
                                   'METHOD': inspect.stack()[0][3],
                                   'EXCEPTION': str(err)})
            raise

    def _create_ssm_input_map(self, ssm_parameters):
        """Build the SSM parameter map for a state machine input.

        Each manifest SSM parameter is transformed as a one-entry
        ``{name: value}`` dict and resolved by the params handler. The results
        are merged in order, so later entries overwrite earlier ones with the
        same key.
        """
        try:
            merged = {}
            for param in ssm_parameters:
                single = transform_params({param.name: param.value})
                merged.update(self.param_handler.update_params(single))
            return merged

        except Exception as err:
            self.logger.exception({'FILE': __file__.split('/')[-1],
                                   'METHOD': inspect.stack()[0][3],
                                   'EXCEPTION': str(err)})
            raise

    def _create_state_machine_input_map(self, input_params, request_type='Create'):
        try:
            self.state_machine_event.update({'RequestType': request_type})
            self.state_machine_event.update({'ResourceProperties': input_params})

        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _create_stack_set_state_machine_input_map(self, stack_set_name, template_url, parameters,
                                                  account_list, regions_list, ssm_map):
        """Build the stack-set ResourceProperties and store them on the event.

        When ``account_list`` is non-empty, the input targets those accounts
        in ``regions_list``, or in the manifest region if that list is empty.
        Otherwise both AccountList and RegionList are set to ``''``.
        SSMParameters is included only when ``ssm_map`` is not None.
        """
        clean_name = sanitize(stack_set_name)

        if len(account_list) > 0:
            accounts = account_list
            regions = regions_list if len(regions_list) > 0 else [self.manifest.region]
        else:
            accounts = ''
            regions = ''

        input_params = {
            'StackSetName': clean_name,
            'TemplateURL': template_url,
            'Parameters': parameters,
            'Capabilities': CAPABILITIES,
            'AccountList': accounts,
            'RegionList': regions,
        }
        if ssm_map is not None:
            input_params['SSMParameters'] = ssm_map

        self._create_state_machine_input_map(input_params)

    def _populate_ssm_params(self):
        """Resolve SSM references in the current event's CloudFormation parameters.

        This runs at execution time rather than when the input is built. A
        core resource may read an SSM parameter that an earlier core
        resource's stack exports, so the value is only available once that
        earlier deployment has finished.

        ResourceProperties.Parameters is passed through ``transform_params``
        and then ``param_handler.update_params``, and the result is written
        back in place.
        """
        try:
            logger.debug("Populating SSM parameter values for SM input: {}".format(self.state_machine_event))
            resource_props = self.state_machine_event.get('ResourceProperties')
            cfn_params = resource_props.get('Parameters', {})
            # transform_params converts the {name: value} dict into the
            # handler's input shape (presumably ParameterKey/ParameterValue
            # entries; confirm). update_params then swaps SSM names for values.
            resolved = self.param_handler.update_params(transform_params(cfn_params))
            resource_props.update({'Parameters': resolved})
            logger.debug("Done populating SSM parameter values for SM input: {}".format(self.state_machine_event))

        except Exception as err:
            self.logger.exception({'FILE': __file__.split('/')[-1],
                                   'METHOD': inspect.stack()[0][3],
                                   'EXCEPTION': str(err)})
            raise

    def _compare_template_and_params(self):
        try:
            stack_name = self.state_machine_event.get('ResourceProperties').get('StackSetName', '')
            flag = False
            if stack_name:
                stack_set = StackSet(self.logger)
                describe_response = stack_set.describe_stack_set(stack_name)
                if describe_response is not None:
                    self.logger.info("Found existing stack set.")

                    self.logger.info("Checking the status of last stack set operation on {}".format(stack_name))
                    response = stack_set.list_stack_set_operations(StackSetName=stack_name,
                                                                   MaxResults=1)

                    if response:
                        if response.get('Summaries'):
                            for instance in response.get('Summaries'):
                                self.logger.info("Status of last stack set operation : {}"
                                                 .format(instance.get('Status')))
                                if instance.get('Status') != 'SUCCEEDED':
                                    self.logger.info("The last stack operation did not succeed. "
                                                     "Triggering Update StackSet for {}".format(stack_name))
                                    return False

                    self.logger.info("Comparing the template of the StackSet: {} with local copy of template"
                                     .format(stack_name))

                    template_http_url = self.state_machine_event.get('ResourceProperties').get('TemplateURL', '')
                    if template_http_url:
                        template_s3_url = convert_http_url_to_s3_url(template_http_url)
                        local_template_file = download_remote_file(self.logger, template_s3_url)
                    else:
                        self.logger.error("TemplateURL in state machine input is empty. Check state_machine_event:{}"
                                          .format(self.state_machine_event))
                        return False

                    cfn_template_file = tempfile.mkstemp()[1]
                    with open(cfn_template_file, "w") as f:
                        f.write(describe_response.get('StackSet').get('TemplateBody'))

                    template_compare = filecmp.cmp(local_template_file, cfn_template_file, False)
                    self.logger.info("Comparing the parameters of the StackSet: {} "
                                     "with local copy of JSON parameters file".format(stack_name))

                    params_compare = True
                    params = self.state_machine_event.get('ResourceProperties').get('Parameters', {})
                    if template_compare:
                        cfn_params = reverse_transform_params(describe_response.get('StackSet').get('Parameters'))
                        for key, value in params.items():
                            if cfn_params.get(key, '') == value:
                                pass
                            else:
                                params_compare = False
                                break

                    self.logger.info("template_compare={}".format(template_compare))
                    self.logger.info("params_compare={}".format(params_compare))
                    if template_compare and params_compare:
                        account_list = self.state_machine_event.get('ResourceProperties').get("AccountList", [])
                        if account_list:
                            self.logger.info("Comparing the Stack Instances Account & Regions for StackSet: {}"
                                             .format(stack_name))
                            expected_region_list = set(self.state_machine_event.get('ResourceProperties').get("RegionList", []))

                            # iterator over accounts in event account list
                            for account in account_list:
                                actual_region_list = set()

                                self.logger.info("### Listing the Stack Instances for StackSet: {} and Account: {} ###"
                                                 .format(stack_name, account))
                                stack_instance_list = stack_set.list_stack_instances_per_account(stack_name, account)

                                self.logger.info(stack_instance_list)

                                if stack_instance_list:
                                    for instance in stack_instance_list:
                                        if instance.get('Status').upper() == 'CURRENT':
                                            actual_region_list.add(instance.get('Region'))
                                        else:
                                            self.logger.info("Found at least one of the Stack Instances in {} state."
                                                             " Triggering Update StackSet for {}"
                                                             .format(instance.get('Status'),
                                                                     stack_name))
                                            return False
                                else:
                                    self.logger.info("Found no stack instances in account: {}, "
                                                     "Updating StackSet: {}".format(account, stack_name))
                                    # # move the account id to index 0
                                    # newindex = 0
                                    # oldindex = self.state_machine_event.get('ResourceProperties').get("AccountList").index(account)
                                    # self.state_machine_event.get('ResourceProperties').get("AccountList").insert(newindex, self.state_machine_event.get('ResourceProperties').get("AccountList").pop(oldindex))
                                    return False

                                if expected_region_list.issubset(actual_region_list):
                                    self.logger.info("Found expected regions : {} in deployed stack instances : {},"
                                                     " so skipping Update StackSet for {}"
                                                     .format(expected_region_list,
                                                             actual_region_list,
                                                             stack_name))
                                    flag = True
                        else:
                            self.logger.info("Found no changes in template & parameters, "
                                             "so skipping Update StackSet for {}".format(stack_name))
                            flag = True
            return flag
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def state_machine_failed(self, status, failed_execution_list):
        """Abort the process when a StackSet state machine run has failed.

        Args:
            status: Aggregate execution status (as returned by
                ``monitor_state_machines_execution_status``). Only the value
                ``'FAILED'`` triggers the abort; anything else is a no-op.
            failed_execution_list: Execution ARNs to list in the error message.

        Raises:
            SystemExit: with exit code 1 when ``status == 'FAILED'``.
        """
        if status != 'FAILED':
            return
        error = " StackSet State Machine Execution(s) Failed. Navigate to the AWS Step Functions console and" \
                " review the following State Machine Executions. ARN List: {}".format(failed_execution_list)
        # Banner lines make the failure easy to spot in long pipeline logs.
        self.logger.error(100 * '*')
        self.logger.error(error)
        self.logger.error(100 * '*')
        sys.exit(1)

    def _run_or_queue_state_machine(self, stackset_name):
        """Start the StackSet state machine for ``stackset_name`` unless nothing changed.

        The execution input is ``self.state_machine_event``. After
        ``_populate_ssm_params`` runs, ``_compare_template_and_params`` is
        consulted; when it returns truthy, no execution is started.

        Sequential mode (``self.sequential_flag`` truthy): the execution ARN is
        recorded and ``monitor_state_machines_execution_status`` polls until no
        tracked execution is RUNNING. If any failed, ``state_machine_failed``
        is called.

        Parallel mode: sleeps ``wait_time`` seconds, then records the ARN so the
        caller can monitor all executions later.

        Args:
            stackset_name: StackSet name. Spaces are removed and the result is
                passed through ``trim_length(..., 50)`` to build the execution name.
        """
        try:
            self.logger.info("State machine Input: {}".format(self.state_machine_event))
            exec_name = "%s-%s-%s" % (self.state_machine_event.get('RequestType'),
                                      trim_length(stackset_name.replace(" ", ""), 50),
                                      time.strftime("%Y-%m-%dT%H-%M-%S"))
            if self.sequential_flag:
                self.logger.info(" > > > > > >  Running Sequential Mode. > > > > > >")
            else:
                self.logger.info(" | | | | | |  Running Parallel Mode. | | | | | |")

            self._populate_ssm_params()
            # Skip the execution entirely when the deployed StackSet already matches.
            if self._compare_template_and_params():
                return

            sm_exec_arn = self.state_machine.trigger_state_machine(self.sm_arn_stackset,
                                                                   self.state_machine_event,
                                                                   exec_name)
            if self.sequential_flag:
                # Wait for the tracked executions before the caller moves on to the next resource.
                self.list_sm_exec_arns.append(sm_exec_arn)
                status, failed_execution_list = self.monitor_state_machines_execution_status()
                if status == 'FAILED':
                    self.state_machine_failed(status, failed_execution_list)
                else:
                    self.logger.info("State Machine execution completed. Starting next execution...")
            else:
                # Parallel mode: space successive triggers wait_time seconds apart.
                time.sleep(int(wait_time))
                self.list_sm_exec_arns.append(sm_exec_arn)
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _deploy_resource(self, resource, account_list):
        try:
            template_full_path = self._stage_template(resource.template_file)
            params = {}
            if resource.parameter_file:
                if len(resource.regions) > 0:
                    params = self._load_params(resource.parameter_file, account_list, resource.regions[0])
                else:
                    params = self._load_params(resource.parameter_file, account_list, self.manifest.region)

            ssm_map = self._create_ssm_input_map(resource.ssm_parameters)

            # Deploying Core resource Stack Set
            stack_name = "CustomControlTower-{}".format(resource.name)
            self._create_stack_set_state_machine_input_map(stack_name, template_full_path,
                                                                      params, account_list, resource.regions, ssm_map)


            self.logger.info(" >>> State Machine Input >>>")
            self.logger.info(self.state_machine_event)

            self._run_or_queue_state_machine(stack_name)
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _get_root_id(self, org):
        response = org.list_roots()
        self.logger.info("Response: List Roots")
        self.logger.info(response)
        return response['Roots'][0].get('Id')

    def _list_ou_for_parent(self, org, parent_id):
        _ou_list = org.list_organizational_units_for_parent(parent_id)
        self.logger.info("Print Organizational Units List under {}".format(parent_id))
        self.logger.info(_ou_list)
        return _ou_list

    def _get_accounts_in_ou(self, org, ou_id_list):
        _accounts_in_ou = []
        accounts_in_all_ous = []
        ou_id_to_account_map = {}

        for _ou_id in ou_id_list:
            _account_list = org.list_accounts_for_parent(_ou_id)
            for _account in _account_list:
                # filter ACTIVE and CREATED accounts
                if _account.get('Status') == "ACTIVE":
                    # create a list of accounts in OU
                    accounts_in_all_ous.append(_account.get('Id'))
                    _accounts_in_ou.append(_account.get('Id'))

            # create a map of accounts for each ou
            self.logger.info("Creating Key:Value Mapping - OU ID: {} ; Account List: {}"
                             .format(_ou_id, _accounts_in_ou))
            ou_id_to_account_map.update({_ou_id: _accounts_in_ou})
            self.logger.info(ou_id_to_account_map)

            # reset list of accounts in the OU
            _accounts_in_ou = []

        self.logger.info("All accounts in OU List: {}".format(accounts_in_all_ous))
        self.logger.info("OU to Account ID mapping")
        self.logger.info(ou_id_to_account_map)
        return accounts_in_all_ous, ou_id_to_account_map

    def _get_ou_ids(self, org):
        # for each OU get list of account
        # get root id
        root_id = self._get_root_id(org)

        # get OUs under the Org root
        ou_list_at_root_level = self._list_ou_for_parent(org, root_id)

        ou_id_list = []
        _ou_name_to_id_map = {}
        _all_ou_ids = []

        for ou_at_root_level in ou_list_at_root_level:
            # build list of all the OU IDs under Org root
            _all_ou_ids.append(ou_at_root_level.get('Id'))
            # build a list of ou id
            _ou_name_to_id_map.update({ou_at_root_level.get('Name'): ou_at_root_level.get('Id')})

        self.logger.info("Print OU Name to OU ID Map")
        self.logger.info(_ou_name_to_id_map)

        # return:
        # 1. OU IDs of the OUs in the manifest
        # 2. Account IDs in OUs in the manifest
        # 3. Account IDs in all the OUs in the manifest
        return _all_ou_ids, _ou_name_to_id_map

    def get_account_for_name(self, org):
        """Map the ``Name`` of every ACTIVE account in the organization to its ``Id``.

        Accounts come from ``org.get_accounts_in_org()``. Accounts whose
        ``Status`` is not ``"ACTIVE"`` are ignored. If names repeat, the later
        account wins.
        """
        name_to_account_map = {
            account.get("Name"): account.get("Id")
            for account in org.get_accounts_in_org()
            if account.get("Status") == "ACTIVE"
        }

        self.logger.info("Print Account Name > Account Mapping")
        self.logger.info(name_to_account_map)
        return name_to_account_map

    def get_organization_details(self):
        """Gather the organization lookups used to resolve manifest targets.

        Returns:
            tuple: ``(accounts_in_all_ous, ou_id_to_account_map,
            ou_name_to_id_map, name_to_account_map)``:

            * ``accounts_in_all_ous`` -- ACTIVE account IDs in any root-level OU
              (used to validate accounts listed in the manifest);
            * ``ou_id_to_account_map`` -- root-level OU ID -> its ACTIVE account IDs;
            * ``ou_name_to_id_map`` -- root-level OU name -> OU ID;
            * ``name_to_account_map`` -- ACTIVE account *name* -> account ID
              (used to convert account names in the manifest to IDs).
        """
        organizations = Organizations(self.logger)
        all_ou_ids, ou_name_to_id_map = self._get_ou_ids(organizations)
        accounts_in_all_ous, ou_id_to_account_map = self._get_accounts_in_ou(organizations, all_ou_ids)
        name_to_account_map = self.get_account_for_name(organizations)
        return accounts_in_all_ous, ou_id_to_account_map, ou_name_to_id_map, name_to_account_map

    def start_stackset_sm(self):
        """Deploy every CloudFormation resource in the manifest as a StackSet.

        For each resource, the target accounts are the union of:

        * the ACTIVE accounts in the root-level OUs named by ``deploy_to_ou``;
        * the ``deploy_to_account`` entries. 12-digit numeric strings are taken
          as account IDs and anything else is looked up by account name. The
          result is then restricted to ACTIVE accounts that sit in a root-level
          OU.

        Missing ``deploy_to_ou`` / ``deploy_to_account`` attributes default to
        ``[]``. After all resources are processed, a ``StackSetCount`` metric is
        sent.

        Raises:
            Exception: if a resource's ``deploy_method`` is not ``stack_set``
                (case-insensitive). Every error is logged and re-raised.
        """
        try:
            self.logger.info("Parsing Core Resources from {} file".format(self.manifest_file_path))
            count = 0

            accounts_in_all_ous, ou_id_to_account_map, ou_name_to_id_map, name_to_account_map = \
                self.get_organization_details()

            for resource in self.manifest.cloudformation_resources:
                self.logger.info(">>>>>>>>> START : {} >>>>>>>>>".format(resource.name))
                # 'deploy_to_ou' / 'deploy_to_account' are optional in the manifest.
                # A missing key may surface as AttributeError (or KeyError for
                # dict-backed objects); only those are treated as "not set".
                try:
                    self.logger.info(resource.deploy_to_ou)
                except (AttributeError, KeyError):
                    resource.deploy_to_ou = []

                try:
                    self.logger.info(resource.deploy_to_account)
                except (AttributeError, KeyError):
                    resource.deploy_to_account = []

                # Resolve OU names to the ACTIVE accounts they contain.
                accounts_in_ou = []
                ou_ids_manifest = []
                if resource.deploy_to_ou:
                    # NOTE(review): case-insensitive *substring* match, so "Dev" also
                    # selects an OU named e.g. "Developers". Confirm this is intended.
                    for ou_name in resource.deploy_to_ou:
                        ou_ids_manifest.extend(value for key, value in ou_name_to_id_map.items()
                                               if ou_name.lower() in key.lower())

                    for ou_id, accounts in ou_id_to_account_map.items():
                        if ou_id in ou_ids_manifest:
                            accounts_in_ou.extend(accounts)

                    self.logger.info(">>> Accounts: {} in OUs: {}".format(accounts_in_ou, resource.deploy_to_ou))

                # Manifest entries may be ints; normalize to strings first.
                account_list = self._convert_list_values_to_string(resource.deploy_to_account)
                self.logger.info(">>>>>> ACCOUNT LIST")
                self.logger.info(account_list)

                # Separate literal 12-digit account IDs from account names.
                name_list = []
                new_account_list = []
                for item in account_list:
                    if item.isdigit() and len(item) == 12:
                        new_account_list.append(item)
                    else:
                        name_list.append(item)
                self.logger.info(new_account_list)
                self.logger.info(name_list)

                # Convert account names to account IDs (same substring matching as OUs).
                for name in name_list:
                    name_account = [value for key, value in name_to_account_map.items()
                                    if name.lower() in key.lower()]
                    self.logger.info("%%%%%%% Name {} -  Account {}".format(name, name_account))
                    new_account_list.extend(name_account)

                # Remove manifest accounts that are not in the organization's OUs or not ACTIVE.
                sanitized_account_list = list(set(new_account_list).intersection(set(accounts_in_all_ous)))
                self.logger.info("Print Updated Manifest Account List")
                self.logger.info(sanitized_account_list)

                # Merge with accounts resolved from OUs and drop duplicates.
                sanitized_account_list.extend(accounts_in_ou)
                sanitized_account_list = list(set(sanitized_account_list))
                self.logger.info("Print merged account list - accounts in manifest + account under OU in manifest")
                self.logger.info(sanitized_account_list)

                if resource.deploy_method.lower() != 'stack_set':
                    raise Exception("Unsupported deploy_method: {} found for resource {} and Account: {} in Manifest"
                                    .format(resource.deploy_method, resource.name, sanitized_account_list))
                self._deploy_resource(resource, sanitized_account_list)
                self.logger.info("<<<<<<<<< FINISH : {} <<<<<<<<<".format(resource.name))

                # Count number of stack sets deployed
                count += 1

            data = {"StackSetCount": str(count)}
            self.send.metrics(data)
            return
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    # return list of strings
    def _convert_list_values_to_string(self, _list):
        return list(map(str, _list))

    # monitor list of state machine executions
    def monitor_state_machines_execution_status(self):
        """Wait until no tracked execution is RUNNING, then summarize the results.

        ``self.list_sm_exec_arns`` is scanned in order. When an execution is
        still ``RUNNING``, the method sleeps ``wait_time`` seconds and restarts
        the scan from the first ARN. Once a full scan finds nothing running,
        every ARN is queried once more and classified.

        Returns:
            tuple: ``('SUCCEEDED', '')`` if every execution SUCCEEDED;
            ``('FAILED', [arn, ...])`` listing the executions that did not; or
            ``(None, [])`` when there is nothing to monitor.
        """
        try:
            if not self.list_sm_exec_arns:
                self.logger.info("SM Execution List {} is empty, nothing to monitor.".format(self.list_sm_exec_arns))
                return None, []

            self.logger.info("Starting to monitor the SM Executions: {}".format(self.list_sm_exec_arns))
            still_running = True
            while still_running:
                still_running = False
                for exec_arn in self.list_sm_exec_arns:
                    if self.state_machine.check_state_machine_status(exec_arn) == 'RUNNING':
                        still_running = True
                        time.sleep(int(wait_time))
                        break

            failed_executions = [exec_arn for exec_arn in self.list_sm_exec_arns
                                 if self.state_machine.check_state_machine_status(exec_arn) != 'SUCCEEDED']
            if failed_executions:
                return 'FAILED', failed_executions
            return 'SUCCEEDED', ''
        except Exception as e:
            error_details = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e),
            }
            self.logger.exception(error_details)
            raise

    def trigger_stackset_state_machine(self):
        """Load the manifest at ``self.manifest_file_path`` and deploy its StackSets.

        The parsed manifest is stored on ``self.manifest`` before
        ``start_stackset_sm`` runs. Any error is logged with file/method
        context and re-raised.
        """
        try:
            self.manifest = Manifest(self.manifest_file_path)
            self.start_stackset_sm()
        except Exception as e:
            self.logger.exception({'FILE': __file__.split('/')[-1],
                                   'METHOD': inspect.stack()[0][3],
                                   'EXCEPTION': str(e)})
            raise
class LaunchSCP(object):
    """Trigger the Service Control Policy (SCP) state machine for manifest policies.

    For every entry in the manifest's ``organization_policies``:

    * the policy file is staged to S3, unless it is already an ``s3`` URL;
    * one state machine execution is started, with executions spaced
      ``wait_time`` seconds apart.

    Executions can then be awaited with ``monitor_state_machines_execution_status``.
    """

    def __init__(self, logger, wait_time, manifest_file_path, sm_arn_scp,
                 staging_bucket):
        self.state_machine = StateMachine(logger)
        self.s3 = S3(logger)
        self.send = Metrics(logger)
        self.param_handler = ParamsHandler(logger)
        self.logger = logger
        self.manifest_file_path = manifest_file_path
        # Directory part of the manifest path; assumes manifest_file_path ends
        # with MANIFEST_FILE_NAME -- TODO confirm for all callers.
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
        self.wait_time = wait_time
        self.sm_arn_scp = sm_arn_scp
        self.manifest = None
        self.list_sm_exec_arns = []
        self.nested_ou_delimiter = ""
        self.staging_bucket = staging_bucket
        self.root_id = None

    def _create_service_control_policy_state_machine_input_map(
            self, policy_name, policy_full_path, policy_desc='', ou_list=None):
        """Build the SCP state machine input for one policy.

        Args:
            policy_name: Policy name (passed through ``sanitize``).
            policy_full_path: URL of the staged policy document.
            policy_desc: Policy description.
            ou_list: ``(ou_name, operation)`` tuples. Defaults to a new empty
                list on every call (previously a shared mutable default).

        Returns:
            dict: ``{'RequestType': 'Create', 'ResourceProperties': {...}}``.
        """
        if ou_list is None:
            ou_list = []
        policy_doc = {
            'Name': sanitize(policy_name),
            'Description': policy_desc,
            'PolicyURL': policy_full_path,
        }
        input_params = {
            'PolicyDocument': policy_doc,
            'AccountId': '',
            'PolicyList': [],
            'Operation': '',
            'OUList': ou_list,
            'OUNameDelimiter': self.nested_ou_delimiter,
        }
        return self._create_state_machine_input_map(input_params)

    def _create_state_machine_input_map(self,
                                        input_params,
                                        request_type='Create'):
        """Wrap ``input_params`` in the RequestType/ResourceProperties envelope."""
        return {'RequestType': request_type, 'ResourceProperties': input_params}

    def _stage_template(self, relative_template_path):
        """Return an HTTPS URL for a policy/template file, uploading it when local.

        * Paths starting with ``s3`` (case-insensitive) are converted with
          ``convert_s3_url_to_http_url`` and not uploaded.
        * Any other path is resolved against the manifest folder, uploaded to
          ``self.staging_bucket`` under ``TEMPLATE_KEY_PREFIX``, and returned
          as an ``https://s3.amazonaws.com/<bucket>/<key>`` URL.
        """
        if relative_template_path.lower().startswith('s3'):
            # s3://bucket-name/object -> https://s3.amazonaws.com/bucket-name/object
            return convert_s3_url_to_http_url(relative_template_path)

        local_file = os.path.join(self.manifest_folder,
                                  relative_template_path)
        remote_file = "{}/{}".format(TEMPLATE_KEY_PREFIX,
                                     relative_template_path)
        self.logger.info(
            "Uploading the template file: {} to S3 bucket: {} and key: {}".
            format(local_file, self.staging_bucket, remote_file))
        self.s3.upload_file(self.staging_bucket, local_file, remote_file)
        return "{}{}{}{}".format('https://s3.amazonaws.com/',
                                 self.staging_bucket, '/', remote_file)

    def _run_or_queue_state_machine(self, sm_input, sm_arn, sm_name):
        """Start one execution of ``sm_arn`` with ``sm_input`` and record its ARN.

        The execution name is ``<RequestType>-<sm_name>-<timestamp>``, where
        ``sm_name`` has spaces removed and is passed through
        ``trim_length(..., 50)``. After triggering, sleeps ``self.wait_time``
        seconds so successive executions are spaced out.
        """
        self.logger.info("State machine Input: {}".format(sm_input))
        exec_name = "%s-%s-%s" % (sm_input.get('RequestType'),
                                  trim_length(sm_name.replace(" ", ""), 50),
                                  time.strftime("%Y-%m-%dT%H-%M-%S"))

        sm_exec_arn = self.state_machine.trigger_state_machine(
            sm_arn, sm_input, exec_name)
        # Use the configured interval rather than a module-level global.
        time.sleep(int(self.wait_time))
        self.list_sm_exec_arns.append(sm_exec_arn)

    def trigger_service_control_policy_state_machine(self):
        """Load the manifest and start one SCP execution per organization policy."""
        try:
            self.manifest = Manifest(self.manifest_file_path)
            self.start_service_control_policy_sm()
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def monitor_state_machines_execution_status(self):
        """Wait until no tracked execution is RUNNING, then summarize the results.

        The ARNs are polled in order. While any execution is ``RUNNING``, the
        method sleeps ``self.wait_time`` seconds and rescans; afterwards each
        ARN is queried once more.

        Returns:
            tuple: ``('SUCCEEDED', '')``, ``('FAILED', [arn, ...])``, or
            ``(None, [])`` when there is nothing to monitor.
        """
        try:
            if not self.list_sm_exec_arns:
                self.logger.info(
                    "SM Execution List {} is empty, nothing to monitor.".
                    format(self.list_sm_exec_arns))
                return None, []

            final_status = 'RUNNING'
            while final_status == 'RUNNING':
                for sm_exec_arn in self.list_sm_exec_arns:
                    status = self.state_machine.check_state_machine_status(
                        sm_exec_arn)
                    if status == 'RUNNING':
                        final_status = 'RUNNING'
                        time.sleep(int(self.wait_time))
                        break
                    else:
                        final_status = 'COMPLETED'

            failed_sm_execution_list = [
                sm_exec_arn for sm_exec_arn in self.list_sm_exec_arns
                if self.state_machine.check_state_machine_status(sm_exec_arn) != 'SUCCEEDED'
            ]
            if failed_sm_execution_list:
                return 'FAILED', failed_sm_execution_list
            return 'SUCCEEDED', ''
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_service_control_policy_sm(self):
        """Start one SCP state machine execution per policy in the manifest.

        Each policy is attached to every OU in ``apply_to_accounts_in_ou``
        (duplicates removed, manifest order kept). An ``SCPPolicyCount`` metric
        is sent afterwards.

        Raises:
            SystemExit: with code 0 when the manifest has no organization policies.
        """
        try:
            self.logger.info("Processing SCPs from {} file".format(
                self.manifest_file_path))
            count = 0

            for policy in self.manifest.organization_policies:
                # dict.fromkeys removes duplicate OUs while keeping manifest
                # order; a set would make the order vary between runs.
                ou_list = [(ou, 'Attach')
                           for ou in dict.fromkeys(policy.apply_to_accounts_in_ou)]

                policy_full_path = self._stage_template(policy.policy_file)
                sm_input = self._create_service_control_policy_state_machine_input_map(
                    policy.name, policy_full_path, policy.description, ou_list)
                # Use the ARN this instance was configured with, not a global.
                self._run_or_queue_state_machine(sm_input, self.sm_arn_scp,
                                                 policy.name)

                # Count number of SCPs
                count += 1

            data = {"SCPPolicyCount": str(count)}
            self.send.metrics(data)

            # Exit where there are no organization policies
            if count == 0:
                self.logger.info("No organization policies are found.")
                sys.exit(0)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise
from states import DefaultState, MenuState, ChooseColorState, ConfirmColorState, picture_interceptor, \
    SentPictureState, GenericState, AlwaysPoppingState, TranslateLanguageState, ChooseLanguageState


class BotUserPersistenceStrategy(PersistenceStrategy):
    """Persistence strategy that keeps conversation state on ``BotUser`` entities.

    States are read from and written to the entity's ``states`` attribute.
    """

    def get_or_create(self, entity_name):
        """Return the ``BotUser`` with id ``entity_name``, or a newly constructed one.

        A new instance is not saved here; ``put`` does that.
        """
        existing = BotUser.get_by_id(entity_name)
        if existing:
            return existing
        return BotUser(id=entity_name)

    def get_states(self, entity):
        """Return the states stored on ``entity``."""
        return entity.states

    def set_states(self, entity, states):
        """Replace the states stored on ``entity``."""
        entity.states = states

    def put(self, entity_name, entity):
        """Save ``entity``; ``entity_name`` is not needed for this backend."""
        entity.put()

# Bot conversation state machine: DefaultState is the initial state, and state
# is persisted on BotUser entities.
state_machine = StateMachine(DefaultState, BotUserPersistenceStrategy())

# The picture interceptor is registered globally, before any state.
state_machine.register_global_interceptor(picture_interceptor)

# Register the remaining states, in the original registration order.
for _state_cls in (MenuState,
                   ChooseColorState,
                   ConfirmColorState,
                   SentPictureState,
                   GenericState,
                   AlwaysPoppingState,
                   ChooseLanguageState,
                   TranslateLanguageState):
    state_machine.register_state(_state_cls)
del _state_cls
예제 #17
0
class LaunchAVM(object):
    """Trigger the Launch-AVM (account vending) state machine for manifest OUs.

    For each OU in the manifest, the OU's first baseline product is looked up
    in the manifest portfolios and its parameter file is loaded. The OU's
    accounts are then listed page by page, and one state machine execution is
    started per page of non-suspended accounts.
    """

    def __init__(self, logger, wait_time, manifest_file_path, sm_arn_launch_avm, batch_size):
        self.state_machine = StateMachine(logger)
        self.ssm = SSM(logger)
        self.param_handler = ParamsHandler(logger)
        self.logger = logger
        self.manifest_file_path = manifest_file_path
        # Directory part of the manifest path; assumes manifest_file_path ends
        # with MANIFEST_FILE_NAME -- TODO confirm for all callers.
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
        self.wait_time = wait_time
        self.sm_arn_launch_avm = sm_arn_launch_avm
        self.manifest = None
        self.list_sm_exec_arns = []
        self.batch_size = batch_size
        self.avm_product_name = None
        self.avm_portfolio_name = None
        self.avm_params = None
        self.root_id = None

    def _load_params(self, relative_parameter_path, account=None, region=None):
        """Load a JSON parameter file (relative to the manifest folder).

        Returns the result of ``param_handler.update_params``. The final
        ``False`` argument defers SSM parameter substitution.
        """
        parameter_file = os.path.join(self.manifest_folder, relative_parameter_path)

        self.logger.info("Parsing the parameter file: {}".format(parameter_file))

        with open(parameter_file, 'r') as content_file:
            parameter_file_content = content_file.read()

        params = json.loads(parameter_file_content)
        # The last parameter is set to False, because we do not want to replace the SSM parameter values yet.
        sm_params = self.param_handler.update_params(params, account, region, False)

        self.logger.info("Input Parameters for State Machine: {}".format(sm_params))
        return sm_params

    def _create_state_machine_input_map(self, input_params, request_type='Create'):
        """Wrap ``input_params`` in the RequestType/ResourceProperties envelope."""
        return {'RequestType': request_type, 'ResourceProperties': input_params}

    def _create_launch_avm_state_machine_input_map(self, portfolio, product, accounts):
        """Build the Launch-AVM input for one portfolio/product and a list of account parameter dicts."""
        input_params = {
            'PortfolioName': sanitize(portfolio, True),
            'ProductName': sanitize(product, True),
            'ProvisioningParametersList': accounts,
        }
        return self._create_state_machine_input_map(input_params)

    def _process_accounts_in_batches(self, accounts, organizations, ou_id, ou_name):
        """Start one Launch-AVM execution for a page of accounts in an OU.

        Accounts whose ``Status`` is SUSPENDED (case-insensitive) are not
        provisioned. They are passed to
        ``organizations.move_account(account_id, ou_id, self.root_id)``,
        presumably moving them from the OU to the root; confirm the argument order.

        For every other account, a copy of ``self.avm_params`` is built in which
        values equal (case-insensitively) to ``AccountEmail``, ``AccountName``
        or ``OrgUnitName`` are replaced by the account's Email, its Name, or
        ``ou_name``.

        If any accounts remain, one execution is started for all of them and
        the method then sleeps ``self.wait_time`` seconds.
        """
        try:
            list_of_accounts = []
            for account in accounts:
                if account.get('Status').upper() == 'SUSPENDED':
                    organizations.move_account(account.get('Id'), ou_id, self.root_id)
                    continue

                placeholder_values = {
                    'accountemail': account.get('Email'),
                    'accountname': account.get('Name'),
                    'orgunitname': ou_name,
                }
                params = {key: placeholder_values.get(value.lower(), value)
                          for key, value in self.avm_params.items()}

                self.logger.info(
                    "Input parameters format for Account: {} are {}".format(account.get('Name'), params))

                list_of_accounts.append(params)

            if list_of_accounts:
                sm_input = self._create_launch_avm_state_machine_input_map(self.avm_portfolio_name,
                                                                           self.avm_product_name.strip(),
                                                                           list_of_accounts)
                self.logger.info("Launch AVM state machine Input: {}".format(sm_input))
                exec_name = "%s-%s-%s-%s" % (sm_input.get('RequestType'), sanitize(ou_name), "Launch-AVM",
                                             time.strftime("%Y-%m-%dT%H-%M-%S"))
                sm_exec_arn = self.state_machine.trigger_state_machine(self.sm_arn_launch_avm, sm_input, exec_name)
                self.list_sm_exec_arns.append(sm_exec_arn)

                # Space out successive executions using the configured interval.
                time.sleep(int(self.wait_time))
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def start_launch_avm(self):
        """Start Launch-AVM executions for every OU listed in the manifest.

        Raises:
            Exception: if an OU's baseline product is not found in the manifest
                portfolios (or its parameter file yields no parameters).
        """
        try:
            self.logger.info("Starting the launch AVM trigger")

            org = Org({}, self.logger)
            organizations = Organizations(self.logger)
            delimiter = self.manifest.nested_ou_delimiter

            response = organizations.list_roots()
            self.logger.info("List roots Response")
            self.logger.info(response)
            self.root_id = response['Roots'][0].get('Id')

            for ou in self.manifest.organizational_units:
                self.avm_product_name = ou.include_in_baseline_products[0]
                # Reset per OU so a missing product cannot silently reuse the
                # previous OU's parameters and portfolio.
                self.avm_params = None
                self.avm_portfolio_name = None

                # Find the AVM for this OU and get the AVM parameters
                for portfolio in self.manifest.portfolios:
                    for product in portfolio.products:
                        if product.name.strip() == self.avm_product_name.strip():
                            self.avm_params = self._load_params(product.parameter_file)
                            self.avm_portfolio_name = portfolio.name.strip()

                if not self.avm_params:
                    raise Exception("Baseline product: {} for OU: {} is not found in the"
                                    " portfolios section of Manifest".format(self.avm_product_name, ou.name))

                ou_id = org._get_ou_id(organizations, self.root_id, ou.name, delimiter)

                self.logger.info("Processing Accounts under: {} in batches of size: {}".format(ou_id, self.batch_size))
                response = organizations.list_accounts_for_parent(ou_id, self.batch_size)
                self.logger.info("List Accounts for Parent OU {} Response".format(ou_id))
                self.logger.info(response)
                self._process_accounts_in_batches(response.get('Accounts'), organizations, ou_id, ou.name)
                next_token = response.get('NextToken', None)

                # Follow pagination until the listing has no NextToken.
                while next_token is not None:
                    self.logger.info("Next Token Returned: {}".format(next_token))
                    response = organizations.list_accounts_for_parent(ou_id, self.batch_size, next_token)
                    self.logger.info("List Accounts for Parent OU {} Response".format(ou_id))
                    self.logger.info(response)
                    self._process_accounts_in_batches(response.get('Accounts'), organizations, ou_id, ou.name)
                    next_token = response.get('NextToken', None)

            return
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def trigger_launch_avm_state_machine(self):
        """Load the manifest and start the Launch-AVM executions."""
        try:
            self.manifest = Manifest(self.manifest_file_path)
            self.start_launch_avm()
            return
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def monitor_state_machines_execution_status(self):
        """Wait until no tracked execution is RUNNING, then summarize the results.

        Returns:
            tuple: ``('SUCCEEDED', '')``, ``('FAILED', [arn, ...])``, or
            ``(None, [])`` when no executions were started. The empty case
            would otherwise poll forever, because an empty scan never leaves
            the RUNNING state.
        """
        try:
            if not self.list_sm_exec_arns:
                self.logger.info("SM Execution List {} is empty, nothing to monitor.".format(self.list_sm_exec_arns))
                return None, []

            final_status = 'RUNNING'
            while final_status == 'RUNNING':
                for sm_exec_arn in self.list_sm_exec_arns:
                    status = self.state_machine.check_state_machine_status(sm_exec_arn)
                    if status == 'RUNNING':
                        final_status = 'RUNNING'
                        time.sleep(int(self.wait_time))
                        break
                    else:
                        final_status = 'COMPLETED'

            failed_sm_execution_list = [
                sm_exec_arn for sm_exec_arn in self.list_sm_exec_arns
                if self.state_machine.check_state_machine_status(sm_exec_arn) != 'SUCCEEDED'
            ]
            if failed_sm_execution_list:
                return 'FAILED', failed_sm_execution_list
            return 'SUCCEEDED', ''

        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise
예제 #18
0
class LaunchAVM(object):
    """Trigger the Launch AVM state machine for the accounts of manifest OUs.

    For every organizational unit in the manifest, the OU's first baseline
    product is located in the manifest portfolios and its parameter file is
    loaded. The OU's accounts are then listed in batches, and each batch that
    contains active accounts triggers one execution of the Launch AVM state
    machine.
    """

    def __init__(self, logger, wait_time, manifest_file_path,
                 sm_arn_launch_avm, batch_size):
        """
        :param logger: project logger (info/debug/exception are used)
        :param wait_time: seconds (int or numeric string) slept after each
            trigger and between status polls
        :param manifest_file_path: path of the manifest file; relative
            parameter file paths are resolved against its folder
        :param sm_arn_launch_avm: ARN of the Launch AVM state machine
        :param batch_size: passed to list_accounts_for_parent when listing the
            accounts of an OU
        """
        self.state_machine = StateMachine(logger)
        self.ssm = SSM(logger)
        self.sc = SC(logger)
        self.param_handler = CFNParamsHandler(logger)
        self.logger = logger
        self.manifest_file_path = manifest_file_path
        # Assumes manifest_file_path ends with MANIFEST_FILE_NAME; stripping it
        # leaves the manifest folder.
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
        self.wait_time = wait_time
        self.sm_arn_launch_avm = sm_arn_launch_avm
        self.manifest = None
        self.list_sm_exec_arns = []
        self.batch_size = batch_size
        self.avm_product_name = None
        self.avm_portfolio_name = None
        self.avm_product_id = None
        self.avm_artifact_id = None
        self.avm_params = None
        self.root_id = None
        self.sc_portfolios = {}  # [portfolio name] = portfolio id
        self.sc_products = {}  # [portfolio name][product name] = product id
        self.provisioned_products = {}  # [provisioned product id] = stack data
        self.provisioned_products_by_account = {
        }  # [account email] = provisioned product id

    def _load_params(self, relative_parameter_path, account=None, region=None):
        """Load a JSON parameter file relative to the manifest folder.

        :param relative_parameter_path: parameter file path from the manifest
        :param account: forwarded to CFNParamsHandler.update_params
        :param region: forwarded to CFNParamsHandler.update_params
        :return: the parameters returned by update_params
        """
        parameter_file = os.path.join(self.manifest_folder,
                                      relative_parameter_path)

        self.logger.info(
            "Parsing the parameter file: {}".format(parameter_file))

        with open(parameter_file, 'r') as content_file:
            params = json.load(content_file)

        # The last parameter is set to False, because we do not want to
        # replace the SSM parameter values yet.
        sm_params = self.param_handler.update_params(params, account, region,
                                                     False)

        self.logger.info(
            "Input Parameters for State Machine: {}".format(sm_params))
        return sm_params

    def _create_launch_avm_state_machine_input_map(self, accounts):
        """Create the input parameters for the Launch AVM state machine.

        :param accounts: list of per-account parameter dicts, passed through
            as ResourceProperties.ProvisioningParametersList
        :return: the state machine input dict. PortfolioExist and ProductExist
            are True only when sc_lookup recorded an id for
            self.avm_portfolio_name / self.avm_product_name.
        """
        portfolio = self.avm_portfolio_name
        product = self.avm_product_name.strip()

        portfolio_id = self.sc_portfolios.get(portfolio)
        product_id = self.sc_products.get(portfolio, {}).get(product)
        # Provisioning artifacts can only be listed for an existing product.
        artifact_id = (self._get_provisioning_artifact_id(product_id)
                       if product_id else None)

        return {
            'RequestType': 'Create',
            'PortfolioId': portfolio_id,
            # bool(), not any(): any() over an id string only tests its
            # characters and raises TypeError when the id is None.
            'PortfolioExist': bool(portfolio_id),
            'ProductId': product_id,
            'ProvisioningArtifactId': artifact_id,
            'ProductExist': bool(product_id),
            'ResourceProperties': {
                'PortfolioName': sanitize(portfolio, True),
                'ProductName': sanitize(product, True),
                'ProvisioningParametersList': accounts,
            },
            # Iteration parameters for the state machine
            'Index': 0,
            'Step': 1,
            'Count': len(accounts),
        }

    def _get_provisioning_artifact_id(self, product_id):
        """Return the Id of the last provisioning artifact listed for product_id.

        :raises Exception: when no provisioning artifact is listed.
        """
        self.logger.info("Listing the provisioning artifact")
        response = self.sc.list_provisioning_artifacts(product_id)
        self.logger.info("List Artifacts Response")
        self.logger.info(response)

        version_list = response.get('ProvisioningArtifactDetails')
        if version_list:
            return version_list[-1].get('Id')
        else:
            raise Exception("Unable to find provisioning artifact id.")

    def _portfolio_in_manifest(self, portname):
        """Return True if a manifest portfolio is named portname.

        Both names are compared after stripping surrounding whitespace.
        """
        portname = portname.strip()
        self.logger.debug('Looking for portfolio {}'.format(portname))
        return any(portname == port.name.strip()
                   for port in self.manifest.portfolios)

    def _product_in_manifest(self, portname, productname):
        """Return True if productname is listed under portfolio portname.

        Only the first manifest portfolio whose (stripped) name matches is
        examined. Names are compared after stripping whitespace.
        """
        portname = portname.strip()
        productname = productname.strip()
        self.logger.debug('Looking for product {} in portfolio {}'.format(
            productname, portname))
        for port in self.manifest.portfolios:
            if portname != port.name.strip():
                continue
            for product in port.products:
                if productname == product.name.strip():
                    self.logger.debug('MATCH')
                    return True
            return False
        return False

    def sc_lookup(self):
        """Record Service Catalog ids for the manifest's portfolios and products.

        Populates self.sc_portfolios[portfolio name] = portfolio id and
        self.sc_products[portfolio name][product name] = product id, limited
        to portfolios and products named in the manifest. This data is used
        when building the Launch AVM state machine input.
        """
        try:
            response = self.sc.list_portfolios()

            for portfolio in response.get('PortfolioDetails') or []:
                portfolio_name = portfolio.get('DisplayName')

                # Is this portfolio in the manifest? If not skip it.
                if not self._portfolio_in_manifest(portfolio_name):
                    continue

                portfolio_id = portfolio.get('Id')
                self.sc_portfolios[portfolio_name] = portfolio_id
                self.sc_products[portfolio_name] = {}

                # Record the manifest products found in this portfolio
                products_response = self.sc.search_products_as_admin(
                    portfolio_id)
                for product in products_response.get(
                        'ProductViewDetails') or []:
                    summary = product.get('ProductViewSummary')
                    portfolio_product_name = summary.get('Name')
                    if not self._product_in_manifest(portfolio_name,
                                                     portfolio_product_name):
                        continue
                    self.sc_products[portfolio_name][
                        portfolio_product_name] = summary.get('ProductId')

            self.logger.debug('DUMP OF SC_PRODUCTS')
            self.logger.debug(json.dumps(self.sc_products, indent=2))

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'CLASS': self.__class__.__name__,
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def get_indexed_provisioned_products(self):
        """Index the provisioned products of the current AVM product id.

        The method pages through search_provisioned_products(self.avm_product_id)
        and skips products whose Status is ERROR or UNDER_CHANGE, or which have
        no PhysicalId. For each remaining product:

        - self.provisioned_products[provisioned product id] is set to the
          get_stack_data() result for its CloudFormation stack.
        - When that data has an AccountEmail,
          self.provisioned_products_by_account[email] is set to the
          provisioned product id.

        Ref: state_machine_handler::search_provisioned_products@2031

        :raises: any error is logged and re-raised. Otherwise accounts could
            be launched against a partial index and be treated as
            unprovisioned.
        """
        try:
            pprods = self.sc.search_provisioned_products(self.avm_product_id)

            while pprods:
                for provisioned_product in pprods.get(
                        'ProvisionedProducts') or []:
                    provisioned_product_id = provisioned_product.get('Id')
                    self.logger.info('PROCESSING ' +
                                     str(provisioned_product_id))
                    self.logger.debug(
                        "ProvisionedProduct:{}".format(provisioned_product))

                    # Ignore products that errored out before, and skip
                    # UNDER_CHANGE ones to avoid looking up the same product.
                    if provisioned_product.get('Status') in ('ERROR',
                                                             'UNDER_CHANGE'):
                        continue

                    # PhysicalId is the stack id, e.g.
                    # arn:aws:cloudformation:${AWS::Region}:${AWS::AccountId}:stack/SC-${AWS::AccountId}-pp-fb3xte4fc4jmk/5790fb30-547b-11e8-b302-50fae98974c5
                    # whose second '/'-separated field is the stack name.
                    stack_id = provisioned_product.get('PhysicalId')
                    self.logger.debug("stack_id={}".format(stack_id))
                    if not stack_id:
                        self.logger.info(
                            "Skipping provisioned product {} without a "
                            "PhysicalId".format(provisioned_product_id))
                        continue
                    stack_name = stack_id.split('/')[1]
                    self.logger.debug("stack_name={}".format(stack_name))

                    # Stack data provides AccountEmail and ExistingParameterKeys
                    stack_data = get_stack_data(stack_name, self.logger)
                    self.provisioned_products[
                        provisioned_product_id] = stack_data

                    # Note: by intentional limitation there is exactly one
                    # provisioned product per Product Id in ALZ, so a single
                    # id is kept per account email.
                    account_email = stack_data.get('AccountEmail', None)
                    if account_email:
                        self.provisioned_products_by_account[
                            account_email] = provisioned_product_id

                token = pprods.get('NextPageToken', None)
                pprods = (self.sc.search_provisioned_products(
                    self.avm_product_id, token) if token else None)

            self.logger.debug('DUMP OF PROVISIONED PRODUCTS')
            self.logger.debug(
                json.dumps(self.provisioned_products,
                           indent=2,
                           default=date_handler))

            self.logger.debug('DUMP OF PROVISIONED PRODUCTS INDEX')
            self.logger.debug(
                json.dumps(self.provisioned_products_by_account,
                           indent=2,
                           default=date_handler))

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def _process_accounts_in_batches(self, accounts, organizations, ou_id,
                                     ou_name):
        """Build and submit one Launch AVM execution for a batch of accounts.

        For each account:

        - SUSPENDED accounts are passed to organizations.move_account
          (ou_id -> self.root_id) and skipped.
        - For each active account, a copy of self.avm_params is made. In the
          copy, values equal (case-insensitively) to 'AccountEmail',
          'AccountName' or 'OrgUnitName' are replaced with the account's
          email, the account's name or ou_name. Any existing provisioned
          product found for the account's email is attached.

        If any active accounts remain, one execution is triggered and its ARN
        is appended to self.list_sm_exec_arns. The method then sleeps
        wait_time seconds.

        Note: sm_input must not exceed 32K max
        """
        try:
            list_of_accounts = []
            for account in accounts or []:
                if account.get('Status').upper() == 'SUSPENDED':
                    organizations.move_account(account.get('Id'), ou_id,
                                               self.root_id)
                    continue

                # Placeholder values in the AVM parameter file -> account data
                substitutions = {
                    'accountemail': account.get('Email'),
                    'accountname': account.get('Name'),
                    'orgunitname': ou_name,
                }
                params = self.avm_params.copy()
                # Iterate the source dict so the copy is not mutated while
                # it is being iterated.
                for key, value in self.avm_params.items():
                    placeholder = value.lower()
                    if placeholder in substitutions:
                        params[key] = substitutions[placeholder]

                # Retrieve the provisioned product id
                ppid = self.provisioned_products_by_account.get(
                    account.get('Email'), None)
                if ppid:
                    params.update({
                        'ProvisionedProductId': ppid,
                        'ProvisionedProductExists': True,
                        'ExistingParameterKeys':
                        self.provisioned_products.get(ppid).get(
                            'ExistingParameterKeys', []),
                    })
                else:
                    params.update({'ProvisionedProductExists': False})

                self.logger.info(
                    "Input parameters format for Account: {} are {}".format(
                        account.get('Name'), params))

                list_of_accounts.append(params)

            if list_of_accounts:
                # list_of_accounts is passed directly through to the input
                # json data built from start_launch_avm's product selection.
                sm_input = self._create_launch_avm_state_machine_input_map(
                    list_of_accounts)
                self.logger.info(
                    "Launch AVM state machine Input: {}".format(sm_input))
                # The OU name (first 40 chars), a timestamp, the fractional
                # part of time.time() and a uuid4 fragment keep the
                # execution name unique across batches.
                exec_name = "%s-%s-%s-%s-%s" % (
                    "AVM",
                    sanitize(ou_name[:40]),
                    time.strftime("%Y-%m-%dT%H-%M-%S"),
                    str(time.time()).split('.')[1],
                    str(uuid4()).split('-')[1])
                sm_exec_arn = self.state_machine.trigger_state_machine(
                    self.sm_arn_launch_avm, sm_input, exec_name)
                self.list_sm_exec_arns.append(sm_exec_arn)

                time.sleep(int(self.wait_time))  # Sleeping for sometime

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_launch_avm(self):
        """Submit Launch AVM state machine executions for every manifest OU.

        For each OU, the method:

        - selects the OU's first baseline product;
        - loads that product's parameter file from the manifest portfolios;
        - indexes the product's existing provisioned products;
        - pages through the OU's accounts and hands each page to
          _process_accounts_in_batches.

        :raises Exception: if an OU's baseline product is not in the manifest
            portfolios. Any other error is logged and re-raised.
        """
        try:
            self.logger.info("Starting the launch AVM trigger")

            org = Org({}, self.logger)
            organizations = Organizations(self.logger)
            delimiter = self.manifest.nested_ou_delimiter or ':'

            response = organizations.list_roots()
            self.logger.debug("List roots Response")
            self.logger.debug(response)
            self.root_id = response['Roots'][0].get('Id')

            for ou in self.manifest.organizational_units:
                self.avm_product_name = ou.include_in_baseline_products[0]
                # Reset the per-OU lookups. Otherwise an OU whose baseline
                # product is missing from the manifest would silently reuse
                # the previous OU's parameters and product id.
                self.avm_params = None
                self.avm_portfolio_name = None
                self.avm_product_id = None

                # Find the AVM for this OU and get the AVM parameters
                for portfolio in self.manifest.portfolios:
                    for product in portfolio.products:
                        if product.name.strip() == \
                                self.avm_product_name.strip():
                            self.avm_params = self._load_params(
                                product.parameter_file)
                            self.avm_portfolio_name = portfolio.name.strip()
                            self.avm_product_id = self.sc_products.get(
                                portfolio.name.strip()).get(
                                    product.name.strip())

                if not self.avm_params:
                    raise Exception("Baseline product: {} for OU: {} is not found in the" \
                      " portfolios section of Manifest".format(self.avm_product_name, ou.name))

                # Load provisioned product data for all accounts of this
                # product in one pass (one large in-memory index instead of
                # per-account API calls).
                self.get_indexed_provisioned_products()

                ou_id = org._get_ou_id(organizations, self.root_id, ou.name,
                                       delimiter)

                self.logger.info(
                    "Processing Accounts under: {} in batches of size: {}".
                    format(ou_id, self.batch_size))
                response = organizations.list_accounts_for_parent(
                    ou_id, self.batch_size)
                self.logger.info(
                    "List Accounts for Parent OU {} Response".format(ou_id))
                self.logger.info(response)
                self._process_accounts_in_batches(response.get('Accounts'),
                                                  organizations, ou_id,
                                                  ou.name)
                next_token = response.get('NextToken', None)

                while next_token is not None:
                    self.logger.info(
                        "Next Token Returned: {}".format(next_token))
                    response = organizations.list_accounts_for_parent(
                        ou_id, self.batch_size, next_token)
                    self.logger.info(
                        "List Accounts for Parent OU {} Response".format(
                            ou_id))
                    self.logger.info(response)
                    self._process_accounts_in_batches(response.get('Accounts'),
                                                      organizations, ou_id,
                                                      ou.name)
                    next_token = response.get('NextToken', None)

            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def trigger_launch_avm_state_machine(self):
        """Load the manifest, look up Service Catalog ids, and launch AVMs."""
        try:
            self.manifest = Manifest(self.manifest_file_path)
            self.sc_lookup()  # Get Service Catalog data
            self.start_launch_avm()
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def monitor_state_machines_execution_status(self):
        """Wait until no triggered execution is RUNNING, then report the outcome.

        Each poll checks self.list_sm_exec_arns in order and stops at the
        first execution that is still RUNNING. In that case it sleeps
        self.wait_time seconds and polls again. An empty list returns
        immediately.

        :return: ['SUCCEEDED', ''] when every execution succeeded (or none
            were triggered), otherwise ['FAILED', [ARNs that did not succeed]].
        """
        try:
            while any(
                    self.state_machine.check_state_machine_status(sm_exec_arn)
                    == 'RUNNING' for sm_exec_arn in self.list_sm_exec_arns):
                time.sleep(int(self.wait_time))

            failed_sm_execution_list = [
                sm_exec_arn for sm_exec_arn in self.list_sm_exec_arns
                if self.state_machine.check_state_machine_status(sm_exec_arn)
                != 'SUCCEEDED'
            ]

            if failed_sm_execution_list:
                return ['FAILED', failed_sm_execution_list]
            return ['SUCCEEDED', '']

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise
예제 #19
0
class SMExecutionManager:
    def __init__(self, logger, sm_input_list):
        """
        :param logger: project logger shared with every helper
        :param sm_input_list: list of state machine input dicts to execute
        """
        # Caller-supplied inputs
        self.logger = logger
        self.sm_input_list = sm_input_list

        # Execution bookkeeping
        self.list_sm_exec_arns = []
        self.stack_set_exist = True

        # Service helpers, all built around the same logger
        self.solution_metrics = Metrics(logger)
        self.param_handler = CFNParamsHandler(logger)
        self.state_machine = StateMachine(logger)
        self.stack_set = StackSet(logger)
        self.s3 = S3(logger)

        # Runtime configuration taken from the environment
        env = os.environ
        self.wait_time = env.get('WAIT_TIME')
        self.execution_mode = env.get('EXECUTION_MODE')

    def launch_executions(self):
        self.logger.info("%%% Launching State Machine Execution %%%")
        if self.execution_mode.upper() == 'PARALLEL':
            self.logger.info(" | | | | |  Running Parallel Mode. | | | | |")
            return self.run_execution_parallel_mode()

        elif self.execution_mode.upper() == 'SEQUENTIAL':
            self.logger.info(" > > > > >  Running Sequential Mode. > > > > >")
            return self.run_execution_sequential_mode()
        else:
            raise Exception("Invalid execution mode: {}".format(
                self.execution_mode))

    def run_execution_sequential_mode(self):
        status, failed_execution_list = None, []
        # start executions at given intervals
        for sm_input in self.sm_input_list:
            updated_sm_input = self.populate_ssm_params(sm_input)
            stack_set_name = sm_input.get('ResourceProperties')\
                .get('StackSetName', '')

            template_matched, parameters_matched = \
                self.compare_template_and_params(sm_input, stack_set_name)

            self.logger.info("Stack Set Name: {} | "
                             "Same Template?: {} | "
                             "Same Parameters?: {}".format(
                                 stack_set_name, template_matched,
                                 parameters_matched))
            if template_matched and parameters_matched and self.stack_set_exist:
                start_execution_flag = not (
                    self.check_stack_instances_per_account(
                        sm_input, stack_set_name))
            else:
                # the template or parameters needs to be updated - start SM exeution
                start_execution_flag = True

            if start_execution_flag:
                sm_exec_name = self.get_sm_exec_name(updated_sm_input)

                sm_exec_arn = self.setup_execution(updated_sm_input,
                                                   sm_exec_name)
                self.list_sm_exec_arns.append(sm_exec_arn)

                status, failed_execution_list = \
                    self.monitor_state_machines_execution_status()
                if status == 'FAILED':
                    return status, failed_execution_list
                else:
                    self.logger.info("State Machine execution completed. "
                                     "Starting next execution...")
        else:
            self.logger.info("All State Machine executions completed.")
        return status, failed_execution_list

    def run_execution_parallel_mode(self):
        # start executions at given intervals
        for sm_input in self.sm_input_list:
            sm_exec_name = self.get_sm_exec_name(sm_input)
            sm_exec_arn = self.setup_execution(sm_input, sm_exec_name)
            self.list_sm_exec_arns.append(sm_exec_arn)
            time.sleep(int(self.wait_time))
        # monitor execution status
        status, failed_execution_list = \
            self.monitor_state_machines_execution_status()
        return status, failed_execution_list

    @staticmethod
    def get_sm_exec_name(sm_input):
        if os.environ.get('STAGE_NAME').upper() == 'BASELINERESOURCES':
            return sm_input.get('ResourceProperties').get('StackSetName')
        else:
            return str(uuid4())  # return random string

    def setup_execution(self, sm_input, name):
        """Trigger one execution of the state machine at $SM_ARN.

        The execution name is "<RequestType>-<name without spaces, trimmed
        to 50>-<timestamp>".

        :return: the value returned by StateMachine.trigger_state_machine
        """
        self.logger.info("State machine Input: {}".format(sm_input))

        compact_name = trim_length_from_end(name.replace(" ", ""), 50)
        timestamp = time.strftime("%Y-%m-%dT%H-%M-%S")
        exec_name = "%s-%s-%s" % (sm_input.get('RequestType'), compact_name,
                                  timestamp)

        state_machine_arn = os.environ.get('SM_ARN')
        return self.state_machine.trigger_state_machine(state_machine_arn,
                                                        sm_input, exec_name)

    def populate_ssm_params(self, sm_input):
        """Resolve SSM parameter references in the input's CFN parameters.

        One CFN resource may export outputs to SSM parameters that a later
        resource reads as inputs. The later resource's SSM references can
        therefore only be resolved once the earlier one has finished, which
        is why this runs just before its execution.

        sm_input['ResourceProperties']['Parameters'] is replaced in place.

        :return: the same sm_input object, updated
        """
        self.logger.info(
            "Populating SSM parameter values for SM input: {}".format(
                sm_input))
        resource_props = sm_input.get('ResourceProperties')
        raw_params = resource_props.get('Parameters', {})
        # transform_params converts the {name: value} mapping (presumably to
        # CloudFormation's ParameterKey/ParameterValue list form) and
        # update_params swaps SSM parameter names for their values.
        resolved_params = self.param_handler.update_params(
            transform_params(raw_params))
        resource_props.update({'Parameters': resolved_params})
        self.logger.info("Done populating SSM parameter values for SM input:"
                         " {}".format(sm_input))
        return sm_input

    def compare_template_and_params(self, sm_input, stack_name):

        self.logger.info("Comparing the templates and parameters.")
        template_compare, params_compare = False, False
        if stack_name:
            describe_response = self.stack_set\
                .describe_stack_set(stack_name)
            self.logger.info("Print Describe Stack Set Response: {}".format(
                describe_response))
            if describe_response is not None:
                self.logger.info("Found existing stack set.")

                operation_status_flag = self.get_stack_set_operation_status(
                    stack_name)

                if operation_status_flag:
                    self.logger.info("Continuing...")
                else:
                    return operation_status_flag, operation_status_flag

                # Compare template copy - START
                self.logger.info(
                    "Comparing the template of the StackSet:"
                    " {} with local copy of template".format(stack_name))

                template_http_url = sm_input.get('ResourceProperties')\
                    .get('TemplateURL', '')
                if template_http_url:
                    bucket_name, key_name = parse_bucket_key_names(
                        template_http_url)
                    local_template_file = tempfile.mkstemp()[1]
                    self.s3.download_file(bucket_name, key_name,
                                          local_template_file)
                else:
                    self.logger.error("TemplateURL in state machine input "
                                      "is empty. Check state_machine_event"
                                      ":{}".format(sm_input))
                    return False, False

                cfn_template_file = tempfile.mkstemp()[1]
                with open(cfn_template_file, "w") as f:
                    f.write(
                        describe_response.get('StackSet').get('TemplateBody'))
                # cmp function return true of the contents are same
                template_compare = filecmp.cmp(local_template_file,
                                               cfn_template_file, False)
                self.logger.info("Comparing the parameters of the StackSet"
                                 ": {} with local copy of JSON parameters"
                                 " file".format(stack_name))

                params_compare = True
                params = sm_input.get('ResourceProperties')\
                    .get('Parameters', {})
                # template are same - compare parameters (skip if template
                # are not same)
                if template_compare:
                    cfn_params = reverse_transform_params(
                        describe_response.get('StackSet').get('Parameters'))
                    for key, value in params.items():
                        if cfn_params.get(key, '') == value:
                            pass
                        else:
                            params_compare = False
                            break

                self.logger.info(
                    "template_compare={}; params_compare={}".format(
                        template_compare, params_compare))
            else:
                self.logger.info(
                    'Stack Set does not exist. Creating a new stack set ....')
                template_compare, params_compare = True, True
                # set this flag to create the stack set
                self.stack_set_exist = False

        return template_compare, params_compare

    def get_stack_set_operation_status(self, stack_name):
        self.logger.info("Checking the status of last stack set "
                         "operation on {}".format(stack_name))
        response = self.stack_set. \
            list_stack_set_operations(StackSetName=stack_name,
                                      MaxResults=1)
        if response:
            if response.get('Summaries'):
                for instance in response.get('Summaries'):
                    self.logger.info("Status of last stack set "
                                     "operation : {}".format(
                                         instance.get('Status')))
                    if instance.get('Status') != 'SUCCEEDED':
                        self.logger.info(
                            "The last stack operation"
                            " did not succeed. "
                            "Triggering "
                            " Update StackSet for {}".format(stack_name))
                        return False
        return True

    def check_stack_instances_per_account(self, sm_input, stack_name):
        """Decide whether the StackSet's stack instances already match the request.

        For each account in ``ResourceProperties.AccountList`` the deployed
        stack instances are listed via
        ``self.stack_set.list_stack_instances_per_account`` and their regions
        are compared with ``ResourceProperties.RegionList``.

        :param sm_input: state machine input; its ``ResourceProperties`` key
            must be present (``AccountList`` / ``RegionList`` default to empty).
        :param stack_name: name of the StackSet to inspect.
        :return: boolean
            False: the SM execution needs to make CRUD operations on the
            StackSet, i.e. some account has no stack instances, has an
            instance whose status is not ``CURRENT``, or is missing at least
            one expected region.
            True: no changes to Stack Set or Stack Instances are required,
            i.e. every account holds all expected regions (also returned
            when ``AccountList`` is empty).
        """
        resource_properties = sm_input.get('ResourceProperties')
        account_list = resource_properties.get("AccountList", [])
        if not account_list:
            self.logger.info("Found no changes in template "
                             "& parameters, so skipping Update  "
                             "StackSet for {}".format(stack_name))
            return True

        self.logger.info("Comparing the Stack Instances "
                         "Account & Regions for "
                         "StackSet: {}".format(stack_name))
        expected_region_list = set(
            resource_properties.get("RegionList", []))

        # iterate over accounts in event account list; any mismatch in any
        # account means the StackSet must be updated.
        for account in account_list:
            self.logger.info("### Listing the Stack "
                             "Instances for StackSet: {}"
                             " and Account: {} ###".format(
                                 stack_name, account))
            stack_instance_list = self.stack_set. \
                list_stack_instances_per_account(stack_name, account)

            self.logger.info(stack_instance_list)

            if not stack_instance_list:
                self.logger.info("Found no stack instances in"
                                 " account: {},Updating "
                                 "StackSet: {}".format(
                                     account, stack_name))
                return False

            actual_region_list = set()
            for instance in stack_instance_list:
                # A missing Status is treated as "not CURRENT" instead of
                # crashing on None.upper().
                if (instance.get('Status') or '').upper() != 'CURRENT':
                    self.logger.info("Found at least one of"
                                     " the Stack Instances"
                                     " in {} state."
                                     " Triggering Update"
                                     " StackSet for {}".format(
                                         instance.get('Status'),
                                         stack_name))
                    return False
                actual_region_list.add(instance.get('Region'))

            # A missing region in ANY account requires an update; it must not
            # be masked by other accounts that do have every region.
            if not expected_region_list.issubset(actual_region_list):
                self.logger.info("Expected regions : {} not all found "
                                 "in deployed stack instances : {} "
                                 "for account: {}, Triggering Update "
                                 "StackSet for {}".format(
                                     expected_region_list,
                                     actual_region_list, account,
                                     stack_name))
                return False

            self.logger.info("Found expected regions : {} "
                             "in deployed stack instances :"
                             " {}, so skipping Update "
                             "StackSet for {}".format(
                                 expected_region_list,
                                 actual_region_list, stack_name))
        return True

    def monitor_state_machines_execution_status(self):
        """Wait until no tracked state machine execution is RUNNING, then report.

        Each ARN in ``self.list_sm_exec_arns`` is polled with
        ``self.state_machine.check_state_machine_status``. Executions that have
        left the ``'RUNNING'`` state have their status recorded and are not
        polled again. While any execution is still running, the method sleeps
        ``int(self.wait_time)`` seconds between polling rounds.

        :return: tuple ``(status, failed)``:
            ``('SUCCEEDED', '')`` when every execution ended ``'SUCCEEDED'``
            (the empty string, not a list, is kept for backward compatibility
            with existing callers);
            ``('FAILED', [arn, ...])`` listing executions that ended with any
            other status, in ``list_sm_exec_arns`` order;
            ``(None, [])`` when there is nothing to monitor.
        """
        if not self.list_sm_exec_arns:
            self.logger.info("SM Execution List {} is empty, nothing to "
                             "monitor.".format(self.list_sm_exec_arns))
            return None, []

        # Final (non-RUNNING) status per execution ARN, recorded once so that
        # finished executions are neither re-polled while waiting for the
        # others nor queried a second time afterwards.
        final_status = {}
        pending = list(self.list_sm_exec_arns)
        while pending:
            still_running = []
            for sm_exec_arn in pending:
                status = self.state_machine.check_state_machine_status(
                    sm_exec_arn)
                if status == 'RUNNING':
                    still_running.append(sm_exec_arn)
                else:
                    final_status[sm_exec_arn] = status
            pending = still_running
            if pending:
                time.sleep(int(self.wait_time))

        failed_sm_execution_list = [
            sm_exec_arn for sm_exec_arn in self.list_sm_exec_arns
            if final_status[sm_exec_arn] != 'SUCCEEDED'
        ]
        if failed_sm_execution_list:
            return 'FAILED', failed_sm_execution_list
        return 'SUCCEEDED', ''
예제 #20
0
from lib.logger import Logger
from lib.state_machine import StateMachine
logger = Logger('critical')
sfn = StateMachine(logger)

# Canned response returned by the mocked StateMachine.trigger_state_machine.
trigger_state_machine_response = {
    "executionArn":
    "arn:aws:states:us-east-1:xxxx:execution:TestStateMachine:test-execution-name",
    "startDate": "yyyy-mm-dd"
}

# declare variables
state_machine_arn = 'arn:aws:states:us-east-1:xxxx:execution:TestStateMachine'
# Named ``execution_input`` (not ``input``) to avoid shadowing the builtin.
execution_input = {}
name = 'test-execution-name'


def test_trigger_state_machine(mocker):
    """Check the trigger_state_machine call contract against a mock.

    NOTE(review): the method under test is itself patched, so this only
    verifies the arguments passed and the canned response round-trip, not
    StateMachine's real behaviour.
    """
    mocker.patch.object(sfn, 'trigger_state_machine')
    sfn.trigger_state_machine.return_value = trigger_state_machine_response
    response = sfn.trigger_state_machine(state_machine_arn, execution_input,
                                         name)
    sfn.trigger_state_machine.assert_called_once_with(
        state_machine_arn, execution_input, name)
    assert response.get('executionArn') == "%s:%s" % (state_machine_arn, name)