Example #1
def test_get_metrics_from_finding(mocker):

    expected_response = {
        'generator_id':
        'arn:aws:securityhub:::ruleset/cis-aws-foundations-benchmark/v/1.2.0/rule/1.3',
        'type':
        '1.3 Ensure credentials unused for 90 days or greater are disabled',
        'productArn':
        'arn:aws:securityhub:' + my_region + '::product/aws/securityhub',
        'finding_triggered_by': 'unit-test',
        'region': mocker.ANY
    }

    finding = utils.load_test_data(test_data + 'cis_1-3-iamuser1.json',
                                   my_region).get('detail').get('findings')[0]

    ssmc = boto3.client('ssm', region_name=my_region)
    ssmc_s = Stubber(ssmc)
    ssmc_s.add_response('get_parameter', mock_ssm_get_parameter_uuid)
    ssmc_s.add_response('get_parameter', mock_ssm_get_parameter_version)
    ssmc_s.activate()

    mocker.patch('lib.metrics.Metrics.connect_to_ssm', return_value=ssmc)

    metrics = Metrics({"detail-type": "unit-test"})

    assert metrics.get_metrics_from_finding(finding) == expected_response
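
The stub fixtures referenced above (mock_ssm_get_parameter_uuid, mock_ssm_get_parameter_version) are not shown on this page. A minimal sketch of what those GetParameter responses could look like follows; the parameter names are assumptions, while the Value fields match what Example #12 asserts:

mock_ssm_get_parameter_uuid = {
    'Parameter': {
        'Name': '/Solutions/SO0111/anonymous_metrics_uuid',  # assumed name
        'Type': 'String',
        'Value': '12345678-1234-1234-1234-123412341234'
    }
}

mock_ssm_get_parameter_version = {
    'Parameter': {
        'Name': '/Solutions/SO0111/solution_version',  # assumed name
        'Type': 'String',
        'Value': 'v1.2.0TEST'
    }
}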
Example #2
def send_execution_data(self):
    try:
        self.logger.info("Executing: " + self.__class__.__name__ + "/" + inspect.stack()[0][3])
        send = Metrics(self.logger)
        data = {"StateMachineExecutionCount": "1"}
        send.metrics(data)
        return self.event
    except Exception:
        # Metrics are best-effort: never fail the execution over a metrics error.
        return self.event
Example #3
def lambda_handler(event, context):

    LOGGER.debug(event)
    metrics = Metrics(event)
    try:
        for finding_rec in event['detail']['findings']:
            finding = Finding(finding_rec)
            remediate(finding, metrics.get_metrics_from_finding(finding_rec))
    except Exception as e:
        LOGGER.error(e)

    APPLOGGER.flush()  # flush the buffer to CW Logs
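
For reference, a hedged sample of the event shape this handler consumes: only the keys the handler actually reads are shown ('detail-type' feeds Metrics, detail.findings feeds remediation), and the finding fields are illustrative.

sample_event = {
    'detail-type': 'Security Hub Findings - Custom Action',
    'detail': {
        'findings': [{
            'GeneratorId': 'arn:aws:securityhub:::ruleset/cis-aws-foundations-benchmark/v/1.2.0/rule/1.3',
            'Title': '1.3 Ensure credentials unused for 90 days or greater are disabled'
        }]
    }
}
lambda_handler(sample_event, None)  # illustrative invocation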
Example #4
def send_tgw_peering_anonymous_data(self) -> Any:  # typing.Any; the builtin `any` is not a type
    final_states = ['available', 'deleted']
    if self.event.get('TgwPeeringAttachmentState') in final_states:
        send = Metrics(self.logger)
        data = {
            "TgwPeeringState": self.event.get('TgwPeeringAttachmentState'),
            "Region": environ.get('AWS_REGION'),
            "PeerRegion": self.event.get('PeerRegion'),
            "RequestType": self.event.get('RequestType'),
            "TagEventSource": "TransitGateway",
            "SolutionVersion": environ.get('SOLUTION_VERSION')
        }
        return send.metrics(data)
    else:
        return None
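
A hedged example of the event this method inspects; the keys mirror what the method reads and the values are illustrative:

event = {
    'TgwPeeringAttachmentState': 'available',  # a final state, so metrics are sent
    'PeerRegion': 'us-west-2',
    'RequestType': 'Create'
}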
Example #5
def test(self):
    cumulative_metrics = Metrics.empty(mode="test")

    for update_info in self.test_updates():
        cumulative_metrics += update_info.metrics

    print(cumulative_metrics)
Example #6
def __init__(self, logger, wait_time, manifest_file_path, sm_arn_scp,
             staging_bucket):
    self.state_machine = StateMachine(logger)
    self.s3 = S3(logger)
    self.send = Metrics(logger)
    self.param_handler = ParamsHandler(logger)
    self.logger = logger
    self.manifest_file_path = manifest_file_path
    self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
    self.wait_time = wait_time
    self.sm_arn_scp = sm_arn_scp
    self.manifest = None
    self.list_sm_exec_arns = []
    self.nested_ou_delimiter = ""
    self.staging_bucket = staging_bucket
    self.root_id = None
Example #7
def send_metrics(self):
    try:
        self.put_ssm()
        self.logger.info(self.params)
        data = {
            "PrincipalType": self.params.get('PrincipalType'),
            "ApprovalNotificationFlag": self.params.get('ApprovalNotification'),
            "AuditTrailRetentionPeriod": self.params.get('AuditTrailRetentionPeriod'),
            "DefaultRoute": self.params.get('DefaultRoute'),
            "Region": get_region(),
            "SolutionVersion": self.params.get('SolutionVersion'),
            "CreatedNewTransitGateway": self.params.get(
                'CreatedNewTransitGateway')
        }
        send = Metrics(self.logger)
        send.metrics(data)
    except Exception as e:
        # Log and continue: a metrics failure must not break the caller.
        self.logger.info(e)
Example #8
def __init__(self, logger, sm_arns_map, staging_bucket, manifest_file_path,
             pipeline_stage, token, execution_mode, primary_account_id):
    self.state_machine = StateMachine(logger)
    self.ssm = SSM(logger)
    self.s3 = S3(logger)
    self.send = Metrics(logger)
    self.param_handler = ParamsHandler(logger)
    self.logger = logger
    self.sm_arns_map = sm_arns_map
    self.manifest = None
    self.staging_bucket = staging_bucket
    self.manifest_file_path = manifest_file_path
    self.token = token
    self.pipeline_stage = pipeline_stage
    self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
    if execution_mode.lower() == 'sequential':
        self.isSequential = True
    else:
        self.isSequential = False
    self.index = 100
    self.primary_account_id = primary_account_id
Example #9
def __init__(self, logger, sm_input_list):
    self.logger = logger
    self.sm_input_list = sm_input_list
    self.list_sm_exec_arns = []
    self.stack_set_exist = True
    self.solution_metrics = Metrics(logger)
    self.param_handler = CFNParamsHandler(logger)
    self.state_machine = StateMachine(logger)
    self.stack_set = StackSet(logger)
    self.s3 = S3(logger)
    self.wait_time = os.environ.get('WAIT_TIME')
    self.execution_mode = os.environ.get('EXECUTION_MODE')
Example #10
def test_send_metrics(mocker):

    expected_response = {
        'Solution': 'SO0111',
        'UUID': '12345678-1234-1234-1234-123412341234',
        'TimeStamp': mocker.ANY,
        'Data': {
            'generator_id':
            'arn:aws:securityhub:::ruleset/cis-aws-foundations-benchmark/v/1.2.0/rule/1.3',
            'type':
            '1.3 Ensure credentials unused for 90 days or greater are disabled',
            'productArn': mocker.ANY,
            'finding_triggered_by': 'unit-test',
            'region': mocker.ANY
        },
        'Version': 'v1.2.0TEST'
    }

    os.environ['sendAnonymousMetrics'] = 'Yes'

    finding = utils.load_test_data(test_data + 'cis_1-3-iamuser1.json',
                                   my_region).get('detail').get('findings')[0]

    ssmc = boto3.client('ssm', region_name=my_region)
    ssmc_s = Stubber(ssmc)
    ssmc_s.add_response('get_parameter', mock_ssm_get_parameter_uuid)
    ssmc_s.add_response('get_parameter', mock_ssm_get_parameter_version)
    ssmc_s.activate()

    mocker.patch('lib.metrics.Metrics.connect_to_ssm', return_value=ssmc)

    metrics = Metrics({"detail-type": "unit-test"})
    metrics_data = metrics.get_metrics_from_finding(finding)

    send_metrics = mocker.patch('lib.metrics.Metrics.post_metrics_to_api',
                                return_value=None)

    metrics.send_metrics(metrics_data)

    send_metrics.assert_called_with(expected_response)
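
The test patches lib.metrics.Metrics.post_metrics_to_api away, so nothing leaves the box. As a point of reference, here is a hedged sketch of what such a poster could do; the endpoint shown is the generic AWS Solutions metrics endpoint and is an assumption for this particular solution:

import json
import urllib.request

def post_metrics_to_api(payload):
    # POST the assembled envelope as JSON; default=str covers the TimeStamp field.
    req = urllib.request.Request(
        'https://metrics.awssolutionsbuilder.com/generic',  # assumed endpoint
        data=json.dumps(payload, default=str).encode('utf-8'),
        headers={'Content-Type': 'application/json'},
        method='POST')
    with urllib.request.urlopen(req) as response:
        return response.status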
Example #11
def __init__(self, logger, wait_time, manifest_file_path, sm_arn_stackset, staging_bucket, execution_mode):
    self.state_machine = StateMachine(logger)
    self.ssm = SSM(logger)
    self.s3 = S3(logger)
    self.send = Metrics(logger)
    self.param_handler = ParamsHandler(logger)
    self.logger = logger
    self.manifest_file_path = manifest_file_path
    self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
    self.wait_time = wait_time
    self.sm_arn_stackset = sm_arn_stackset
    self.manifest = None
    self.list_sm_exec_arns = []
    self.staging_bucket = staging_bucket
    self.root_id = None
    self.uuid = uuid4()
    self.state_machine_event = {}
    # The log line is identical in both branches, so log once.
    self.logger.info("Running {} mode".format(execution_mode))
    self.sequential_flag = execution_mode.lower() == 'sequential'
Example #12
def test_metrics_construction(mocker):

    ssmc = boto3.client('ssm', region_name=my_region)
    ssmc_s = Stubber(ssmc)
    ssmc_s.add_response('get_parameter', mock_ssm_get_parameter_uuid)
    ssmc_s.add_response('get_parameter', mock_ssm_get_parameter_version)
    ssmc_s.activate()

    mocker.patch('lib.metrics.Metrics.connect_to_ssm', return_value=ssmc)

    metrics = Metrics({"detail-type": "unit-test"})

    assert metrics.solution_uuid == "12345678-1234-1234-1234-123412341234"
    assert metrics.solution_version == "v1.2.0TEST"
Example #13
def metrics(self):
    text = None
    predicted = None

    if self.mode != "train":
        text = self.batch.text
        predicted = self.decoded_inferred_texts

    return Metrics(
        self.mode,
        self.loss,
        self.model_loss,
        self.fooling_loss,
        text,
        predicted
    )
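
Taken together with the accumulation test in Example #5, this constructor implies that Metrics supports Metrics.empty(mode=...) and +. A hedged sketch of that protocol follows; the merge rules (summing losses, concatenating texts) are assumptions, not the project's actual schema:

class Metrics:
    def __init__(self, mode, loss=0.0, model_loss=0.0, fooling_loss=0.0,
                 text=None, predicted=None):
        self.mode = mode
        self.loss = loss
        self.model_loss = model_loss
        self.fooling_loss = fooling_loss
        self.text = list(text) if text is not None else []
        self.predicted = list(predicted) if predicted is not None else []

    @classmethod
    def empty(cls, mode):
        # Zeroed instance, so `cumulative += update.metrics` can start from nothing.
        return cls(mode)

    def __add__(self, other):
        # Sum the losses and concatenate the decoded texts (assumed merge rule).
        return Metrics(self.mode,
                       self.loss + other.loss,
                       self.model_loss + other.model_loss,
                       self.fooling_loss + other.fooling_loss,
                       self.text + other.text,
                       self.predicted + other.predicted)

    def __str__(self):
        return "{} loss={:.4f} model_loss={:.4f} fooling_loss={:.4f}".format(
            self.mode, self.loss, self.model_loss, self.fooling_loss)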
Example #14
        cam.take_photo()
        time.sleep(UPDATE_PHOTO_INTERVAL)


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s %(message)s', level=LOG_LEVEL)

    logging.info('Starting...')

    cam = Camera()
    prop = Property()
    relays = Relays()
    triac_hat = TriacHat()
    sensors = Sensors()
    growing = Growing()
    metrics = Metrics()
    fan = Fan()
    light = Light()
    humidify = Humidify()
    weather = Weather()

    # Apply initial settings
    fan.init(triac_hat)

    start_prometheus_exporter(EXPORTER_SERVER_PORT)
    logging.debug('Prometheus exporter listen on 0.0.0.0:{port}'.format(port=EXPORTER_SERVER_PORT))

    update_metrics()
    light_control()
    fan_control()
    humidify_control()
Example #15
class StateMachineTriggerLambda(object):
    def __init__(self, logger, sm_arns_map, staging_bucket, manifest_file_path,
                 pipeline_stage, token, execution_mode, primary_account_id):
        self.state_machine = StateMachine(logger)
        self.ssm = SSM(logger)
        self.s3 = S3(logger)
        self.send = Metrics(logger)
        self.param_handler = ParamsHandler(logger)
        self.logger = logger
        self.sm_arns_map = sm_arns_map
        self.manifest = None
        self.staging_bucket = staging_bucket
        self.manifest_file_path = manifest_file_path
        self.token = token
        self.pipeline_stage = pipeline_stage
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
        if execution_mode.lower() == 'sequential':
            self.isSequential = True
        else:
            self.isSequential = False
        self.index = 100
        self.primary_account_id = primary_account_id

    def _save_sm_exec_arn(self, list_sm_exec_arns):
        if list_sm_exec_arns is not None and isinstance(list_sm_exec_arns, list):
            self.logger.debug(
                "Saving the token:{} with list of sm_exec_arns:{}".format(
                    self.token, list_sm_exec_arns))
            if len(list_sm_exec_arns) > 0:
                # Create a comma-separated string from the list, e.g. 'a,b,c'
                sm_exec_arns = ",".join(list_sm_exec_arns)
                self.ssm.put_parameter(
                    self.token,
                    sm_exec_arns)  # Store the list of SM execution ARNs in SSM
            else:
                self.ssm.put_parameter(self.token, 'PASS')
        else:
            raise Exception(
                "Expecting a list of state machine execution ARNs to store in SSM for token:{}, but found nothing to store."
                .format(self.token))

    def _stage_template(self, relative_template_path):
        if relative_template_path.lower().startswith('s3'):
            # Convert the remote template URL s3://bucket-name/object
            # to Virtual-hosted style URL https://bucket-name.s3.amazonaws.com/object
            t = relative_template_path.split("/", 3)
            s3_url = "https://{}.s3.amazonaws.com/{}".format(t[2], t[3])
        else:
            local_file = os.path.join(self.manifest_folder,
                                      relative_template_path)
            remote_file = "{}/{}_{}".format(
                TEMPLATE_KEY_PREFIX, self.token,
                relative_template_path[relative_template_path.rfind('/') + 1:])
            logger.info(
                "Uploading the template file: {} to S3 bucket: {} and key: {}".
                format(local_file, self.staging_bucket, remote_file))
            self.s3.upload_file(self.staging_bucket, local_file, remote_file)
            s3_url = "{}{}{}{}".format('https://s3.amazonaws.com/',
                                       self.staging_bucket, '/', remote_file)
        return s3_url

    def _download_remote_file(self, remote_s3_path):
        _file = tempfile.mkstemp()[1]
        t = remote_s3_path.split("/", 3)  # s3://bucket-name/key
        remote_bucket = t[2]  # Bucket name
        remote_key = t[3]  # Key
        logger.info("Downloading {}/{} from S3 to {}".format(
            remote_bucket, remote_key, _file))
        self.s3.download_file(remote_bucket, remote_key, _file)
        return _file

    def _load_policy(self, relative_policy_path):
        if relative_policy_path.lower().startswith('s3'):
            policy_file = self._download_remote_file(relative_policy_path)
        else:
            policy_file = os.path.join(self.manifest_folder,
                                       relative_policy_path)

        logger.info("Parsing the policy file: {}".format(policy_file))

        with open(policy_file, 'r') as content_file:
            policy_file_content = content_file.read()

        # Check that the policy file contains valid JSON
        json.loads(policy_file_content)
        # Return the escaped JSON text
        return policy_file_content.replace('"', '\\"').replace('\n', '\r\n')

    def _load_params(self, relative_parameter_path, account=None, region=None):
        if relative_parameter_path.lower().startswith('s3'):
            parameter_file = self._download_remote_file(
                relative_parameter_path)
        else:
            parameter_file = os.path.join(self.manifest_folder,
                                          relative_parameter_path)

        logger.info("Parsing the parameter file: {}".format(parameter_file))

        with open(parameter_file, 'r') as content_file:
            parameter_file_content = content_file.read()

        params = json.loads(parameter_file_content)
        if account is not None:
            # Deploying Core resource Stack Set.
            # The last argument is False because we do not want to replace the SSM parameter values yet.
            sm_params = self.param_handler.update_params(
                params, account, region, False)
        else:
            # Deploying Baseline resource Stack Set
            sm_params = self.param_handler.update_params(params)

        logger.info("Input Parameters for State Machine: {}".format(sm_params))
        return sm_params

    def _load_template_rules(self, relative_rules_path):
        rules_file = os.path.join(self.manifest_folder, relative_rules_path)
        logger.info("Parsing the template rules file: {}".format(rules_file))

        with open(rules_file, 'r') as content_file:
            rules_file_content = content_file.read()

        rules = json.loads(rules_file_content)

        logger.info(
            "Template Constraint Rules for State Machine: {}".format(rules))

        return rules

    def _populate_ssm_params(self, sm_input):
        # Scenario: one core resource exports a CloudFormation stack output to
        # an SSM parameter, and the next core resource reads that SSM parameter
        # as its input. The second resource must therefore wait for the first
        # to finish, then read the SSM parameter and use its value in its own
        # state machine input.
        # Get the parameters for the CFN template from sm_input.
        logger.debug("Populating SSM parameter values for SM input: {}".format(
            sm_input))
        params = sm_input.get('ResourceProperties').get('Parameters', {})
        # First transform it from {name: value} to
        # [{'ParameterKey': name, 'ParameterValue': value}, ...],
        # then replace the SSM parameter names with their values
        sm_params = self.param_handler.update_params(transform_params(params))
        # Put it back into the sm_input
        sm_input.get('ResourceProperties').update({'Parameters': sm_params})
        logger.debug(
            "Done populating SSM parameter values for SM input: {}".format(
                sm_input))
        return sm_input

    def _create_ssm_input_map(self, ssm_parameters):
        ssm_input_map = {}

        for ssm_parameter in ssm_parameters:
            key = ssm_parameter.name
            value = ssm_parameter.value
            ssm_value = self.param_handler.update_params(
                transform_params({key: value}))
            ssm_input_map.update(ssm_value)

        return ssm_input_map

    def _create_state_machine_input_map(self,
                                        input_params,
                                        request_type='Create'):
        request = {}
        request.update({'RequestType': request_type})
        request.update({'ResourceProperties': input_params})

        return request

    def _create_account_state_machine_input_map(self,
                                                ou_name,
                                                account_name='',
                                                account_email='',
                                                ssm_map=None):
        input_params = {}
        input_params.update({'OUName': ou_name})
        input_params.update({'AccountName': account_name})
        input_params.update({'AccountEmail': account_email})
        if ssm_map is not None:
            input_params.update({'SSMParameters': ssm_map})
        return self._create_state_machine_input_map(input_params)

    def _create_stack_set_state_machine_input_map(
            self,
            stack_set_name,
            template_url,
            parameters,
            account_list=[],
            regions_list=[],
            ssm_map=None,
            capabilities='CAPABILITY_NAMED_IAM'):
        input_params = {}
        input_params.update({'StackSetName': sanitize(stack_set_name)})
        input_params.update({'TemplateURL': template_url})
        input_params.update({'Parameters': parameters})
        input_params.update({'Capabilities': capabilities})

        if len(account_list) > 0:
            input_params.update({'AccountList': account_list})
            if len(regions_list) > 0:
                input_params.update({'RegionList': regions_list})
            else:
                input_params.update({'RegionList': [self.manifest.region]})
        else:
            input_params.update({'AccountList': ''})
            input_params.update({'RegionList': ''})

        if ssm_map is not None:
            input_params.update({'SSMParameters': ssm_map})

        return self._create_state_machine_input_map(input_params)

    def _create_service_control_policy_state_machine_input_map(
            self, policy_name, policy_content, policy_desc=''):
        input_params = {}
        policy_doc = {}
        policy_doc.update({'Name': sanitize(policy_name)})
        policy_doc.update({'Description': policy_desc})
        policy_doc.update({'Content': policy_content})
        input_params.update({'PolicyDocument': policy_doc})
        input_params.update({'AccountId': ''})
        input_params.update({'PolicyList': []})
        input_params.update({'Operation': ''})
        return self._create_state_machine_input_map(input_params)

    def _create_service_catalog_state_machine_input_map(
            self, portfolio, product):
        input_params = {}

        sc_portfolio = {}
        sc_portfolio.update({'PortfolioName': sanitize(portfolio.name, True)})
        sc_portfolio.update(
            {'PortfolioDescription': sanitize(portfolio.description, True)})
        sc_portfolio.update(
            {'PortfolioProvider': sanitize(portfolio.owner, True)})
        ssm_value = self.param_handler.update_params(
            transform_params({'principal_role': portfolio.principal_role}))
        sc_portfolio.update({'PrincipalArn': ssm_value.get('principal_role')})

        sc_product = {}
        sc_product.update({'ProductName': sanitize(product.name, True)})
        sc_product.update({'ProductDescription': product.description})
        sc_product.update({'ProductOwner': sanitize(portfolio.owner, True)})
        if product.hide_old_versions is True:
            sc_product.update({'HideOldVersions': 'Yes'})
        else:
            sc_product.update({'HideOldVersions': 'No'})
        ssm_value = self.param_handler.update_params(
            transform_params(
                {'launch_constraint_role': product.launch_constraint_role}))
        sc_product.update({'RoleArn': ssm_value.get('launch_constraint_role')})

        ec2 = EC2(self.logger, environ.get('AWS_REGION'))
        region_list = []
        for region in ec2.describe_regions():
            region_list.append(region.get('RegionName'))

        if os.path.isfile(
                os.path.join(self.manifest_folder, product.skeleton_file)):
            lambda_arn_param = get_env_var('lambda_arn_param_name')
            lambda_arn = self.ssm.get_parameter(lambda_arn_param)
            portfolio_index = self.manifest.portfolios.index(portfolio)
            product_index = self.manifest.portfolios[
                portfolio_index].products.index(product)
            product_name = self.manifest.portfolios[portfolio_index].products[
                product_index].name
            logger.info(
                "Generating the product template for {} from {}".format(
                    product_name,
                    os.path.join(self.manifest_folder, product.skeleton_file)))
            j2loader = jinja2.FileSystemLoader(self.manifest_folder)
            j2env = jinja2.Environment(loader=j2loader)
            j2template = j2env.get_template(product.skeleton_file)
            template_url = None
            if product.product_type.lower() == 'baseline':
                # j2result = j2template.render(manifest=self.manifest, portfolio_index=portfolio_index,
                #                              product_index=product_index, lambda_arn=lambda_arn, uuid=uuid.uuid4(),
                #                              regions=region_list)
                template_url = self._stage_template(product.skeleton_file +
                                                    ".template")
            elif product.product_type.lower() == 'optional':
                if len(product.template_file) > 0:
                    template_url = self._stage_template(product.template_file)
                    j2result = j2template.render(
                        manifest=self.manifest,
                        portfolio_index=portfolio_index,
                        product_index=product_index,
                        lambda_arn=lambda_arn,
                        uuid=uuid.uuid4(),
                        template_url=template_url)
                    generated_avm_template = os.path.join(
                        self.manifest_folder,
                        product.skeleton_file + ".generated.template")
                    logger.info(
                        "Writing the generated product template to {}".format(
                            generated_avm_template))
                    with open(generated_avm_template, "w") as fh:
                        fh.write(j2result)
                    template_url = self._stage_template(generated_avm_template)
                else:
                    raise Exception(
                        "Missing template_file location for portfolio:{} and product:{} in Manifest file"
                        .format(portfolio.name, product.name))

        else:
            raise Exception(
                "Missing skeleton_file for portfolio:{} and product:{} in Manifest file"
                .format(portfolio.name, product.name))

        artifact_params = {}
        artifact_params.update({'Info': {'LoadTemplateFromURL': template_url}})
        artifact_params.update({'Type': 'CLOUD_FORMATION_TEMPLATE'})
        artifact_params.update({'Description': product.description})
        sc_product.update({'ProvisioningArtifactParameters': artifact_params})

        try:
            if product.rules_file:
                rules = self._load_template_rules(product.rules_file)
                sc_product.update({'Rules': rules})
        except Exception as e:
            logger.error(e)

        input_params.update({'SCPortfolio': sc_portfolio})
        input_params.update({'SCProduct': sc_product})

        return self._create_state_machine_input_map(input_params)

    def _create_launch_avm_state_machine_input_map(self, portfolio, product,
                                                   accounts):
        input_params = {}
        input_params.update({'PortfolioName': sanitize(portfolio, True)})
        input_params.update({'ProductName': sanitize(product, True)})
        input_params.update({'ProvisioningParametersList': accounts})
        return self._create_state_machine_input_map(input_params)

    def _run_or_queue_state_machine(self, sm_input, sm_arn, list_sm_exec_arns,
                                    sm_name):
        logger.info("State machine Input: {}".format(sm_input))
        exec_name = "%s-%s-%s" % (sm_input.get('RequestType'),
                                  sm_name.replace(" ", ""),
                                  time.strftime("%Y-%m-%dT%H-%M-%S"))
        # If sequential, kick off the first SM and save the state machine input
        # JSON for the rest in SSM Parameter Store under the /<token>/<index> tree
        if self.isSequential:
            if self.index == 100:
                sm_input = self._populate_ssm_params(sm_input)
                sm_exec_arn = self.state_machine.trigger_state_machine(
                    sm_arn, sm_input, exec_name)
                list_sm_exec_arns.append(sm_exec_arn)
            else:
                param_name = "/%s/%s" % (self.token, self.index)
                self.ssm.put_parameter(param_name, json.dumps(sm_input))
        # Else, if parallel, execute all state machines at a regular interval of wait_time
        else:
            sm_input = self._populate_ssm_params(sm_input)
            sm_exec_arn = self.state_machine.trigger_state_machine(
                sm_arn, sm_input, exec_name)
            time.sleep(int(wait_time))  # brief pause between executions
            list_sm_exec_arns.append(sm_exec_arn)
        self.index = self.index + 1

    def _deploy_resource(self,
                         resource,
                         sm_arn,
                         list_sm_exec_arns,
                         account_id=None):
        template_full_path = self._stage_template(resource.template_file)
        params = {}
        deploy_resource_flag = True
        if resource.parameter_file:
            if len(resource.regions) > 0:
                params = self._load_params(resource.parameter_file, account_id,
                                           resource.regions[0])
            else:
                params = self._load_params(resource.parameter_file, account_id,
                                           self.manifest.region)

        ssm_map = self._create_ssm_input_map(resource.ssm_parameters)

        if account_id is not None:
            # Deploying Core resource Stack Set
            stack_name = "AWS-Landing-Zone-{}".format(resource.name)
            sm_input = self._create_stack_set_state_machine_input_map(
                stack_name, template_full_path, params, [str(account_id)],
                resource.regions, ssm_map)
        else:
            # Deploying Baseline resource Stack Set
            stack_name = "AWS-Landing-Zone-Baseline-{}".format(resource.name)
            sm_input = self._create_stack_set_state_machine_input_map(
                stack_name, template_full_path, params, [], [], ssm_map)

            stack_set = StackSet(self.logger)
            response = stack_set.describe_stack_set(stack_name)
            if response is not None:
                self.logger.info("Found existing stack set.")
                self.logger.info(
                    "Comparing the template of the StackSet: {} with local copy of template"
                    .format(stack_name))
                relative_template_path = resource.template_file
                if relative_template_path.lower().startswith('s3'):
                    local_template_file = self._download_remote_file(
                        relative_template_path)
                else:
                    local_template_file = os.path.join(self.manifest_folder,
                                                       relative_template_path)

                cfn_template_file = tempfile.mkstemp()[1]
                with open(cfn_template_file, "w") as f:
                    f.write(response.get('StackSet').get('TemplateBody'))

                template_compare = filecmp.cmp(local_template_file,
                                               cfn_template_file)
                self.logger.info(
                    "Comparing the parameters of the StackSet: {} with local copy of JSON parameters file"
                    .format(stack_name))
                params_compare = True
                if template_compare:
                    cfn_params = reverse_transform_params(
                        response.get('StackSet').get('Parameters'))
                    for key, value in params.items():
                        if cfn_params.get(key, '') == value:
                            pass
                        else:
                            params_compare = False
                            break

                self.logger.info(
                    "template_compare={}".format(template_compare))
                self.logger.info("params_compare={}".format(params_compare))
                if template_compare and params_compare:
                    deploy_resource_flag = False
                    self.logger.info(
                        "Found no changes in template & parameters, so skipping Update StackSet for {}"
                        .format(stack_name))

        if deploy_resource_flag:
            self._run_or_queue_state_machine(sm_input, sm_arn,
                                             list_sm_exec_arns, stack_name)

    def start_core_account_sm(self, sm_arn_account):
        try:
            logger.info("Setting the lock_down_stack_sets_role={}".format(
                self.manifest.lock_down_stack_sets_role))

            if self.manifest.lock_down_stack_sets_role is True:
                self.ssm.put_parameter('lock_down_stack_sets_role_flag', 'yes')
            else:
                self.ssm.put_parameter('lock_down_stack_sets_role_flag', 'no')

            # Send metric - pipeline run count
            data = {"PipelineRunCount": "1"}
            self.send.metrics(data)

            logger.info("Processing Core Accounts from {} file".format(
                self.manifest_file_path))
            list_sm_exec_arns = []
            for ou in self.manifest.organizational_units:
                ou_name = ou.name
                logger.info(
                    "Generating the state machine input json for OU: {}".
                    format(ou_name))

                if len(ou.core_accounts) == 0:
                    # Empty OU with no Accounts
                    sm_input = self._create_account_state_machine_input_map(
                        ou_name)
                    self._run_or_queue_state_machine(sm_input, sm_arn_account,
                                                     list_sm_exec_arns,
                                                     ou_name)

                for account in ou.core_accounts:
                    account_name = account.name

                    if account_name.lower() == 'primary':
                        account_email = ''
                    else:
                        account_email = account.email
                        if not account_email:
                            raise Exception(
                                "Failed to retrieve the email address for the Account: {}"
                                .format(account_name))

                    ssm_map = self._create_ssm_input_map(
                        account.ssm_parameters)

                    sm_input = self._create_account_state_machine_input_map(
                        ou_name, account_name, account_email, ssm_map)
                    self._run_or_queue_state_machine(sm_input, sm_arn_account,
                                                     list_sm_exec_arns,
                                                     account_name)
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_core_resource_sm(self, sm_arn_stack_set):
        try:
            logger.info("Parsing Core Resources from {} file".format(
                self.manifest_file_path))
            list_sm_exec_arns = []
            count = 0
            for ou in self.manifest.organizational_units:
                for account in ou.core_accounts:
                    account_name = account.name
                    account_id = ''
                    for ssm_parameter in account.ssm_parameters:
                        if ssm_parameter.value == '$[AccountId]':
                            account_id = self.ssm.get_parameter(
                                ssm_parameter.name)

                    if account_id == '':
                        raise Exception(
                            "Missing required SSM parameter: {} to retrieve the account Id of Account: {} defined in Manifest"
                            .format(ssm_parameter.name, account_name))

                    for resource in account.core_resources:
                        # Count number of stacksets
                        count += 1
                        if resource.deploy_method.lower() == 'stack_set':
                            self._deploy_resource(resource, sm_arn_stack_set,
                                                  list_sm_exec_arns,
                                                  account_id)
                        else:
                            raise Exception(
                                "Unsupported deploy_method: {} found for resource {} and Account: {} in Manifest"
                                .format(resource.deploy_method, resource.name,
                                        account_name))
            data = {"CoreAccountStackSetCount": str(count)}
            self.send.metrics(data)
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_service_control_policy_sm(self, sm_arn_scp):
        try:
            logger.info("Processing SCPs from {} file".format(
                self.manifest_file_path))
            list_sm_exec_arns = []
            count = 0
            for policy in self.manifest.organization_policies:
                policy_content = self._load_policy(policy.policy_file)
                sm_input = self._create_service_control_policy_state_machine_input_map(
                    policy.name, policy_content, policy.description)
                self._run_or_queue_state_machine(sm_input, sm_arn_scp,
                                                 list_sm_exec_arns,
                                                 policy.name)
                # Count the number of SCPs
                count += 1
            self._save_sm_exec_arn(list_sm_exec_arns)
            data = {"SCPPolicyCount": str(count)}
            self.send.metrics(data)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_service_catalog_sm(self, sm_arn_sc):
        try:
            logger.info(
                "Processing Service catalogs section from {} file".format(
                    self.manifest_file_path))
            list_sm_exec_arns = []
            for portfolio in self.manifest.portfolios:
                for product in portfolio.products:
                    sm_input = self._create_service_catalog_state_machine_input_map(
                        portfolio, product)
                    self._run_or_queue_state_machine(sm_input, sm_arn_sc,
                                                     list_sm_exec_arns,
                                                     product.name)
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_baseline_resources_sm(self, sm_arn_stack_set):
        try:
            logger.info("Parsing Baseline Resources from {} file".format(
                self.manifest_file_path))
            list_sm_exec_arns = []
            count = 0
            for resource in self.manifest.baseline_resources:
                if resource.deploy_method.lower() == 'stack_set':
                    self._deploy_resource(resource, sm_arn_stack_set,
                                          list_sm_exec_arns)
                    # Count number of stacksets
                    count += 1
                else:
                    raise Exception(
                        "Unsupported deploy_method: {} found for resource {} in Manifest"
                        .format(resource.deploy_method, resource.name))
            data = {"BaselineStackSetCount": str(count)}
            self.send.metrics(data)
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_launch_avm(self, sm_arn_launch_avm):
        try:
            logger.info("Starting the launch AVM trigger")
            list_sm_exec_arns = []
            ou_id_map = {}

            org = Organizations(self.logger)
            response = org.list_roots()
            self.logger.info("List roots Response")
            self.logger.info(response)
            root_id = response['Roots'][0].get('Id')

            response = org.list_organizational_units_for_parent(
                ParentId=root_id)
            next_token = response.get('NextToken', None)

            for ou in response['OrganizationalUnits']:
                ou_id_map.update({ou.get('Name'): ou.get('Id')})

            while next_token is not None:
                response = org.list_organizational_units_for_parent(
                    ParentId=root_id, NextToken=next_token)
                next_token = response.get('NextToken', None)
                for ou in response['OrganizationalUnits']:
                    ou_id_map.update({ou.get('Name'): ou.get('Id')})

            self.logger.info("ou_id_map={}".format(ou_id_map))

            for portfolio in self.manifest.portfolios:
                for product in portfolio.products:
                    if product.product_type.lower() == 'baseline':
                        _params = self._load_params(product.parameter_file)
                        logger.info(
                            "Input parameters format for AVM: {}".format(
                                _params))
                        list_of_accounts = []
                        for ou in product.apply_baseline_to_accounts_in_ou:
                            self.logger.debug(
                                "Looking up ou={} in ou_id_map".format(ou))
                            ou_id = ou_id_map.get(ou)
                            self.logger.debug(
                                "ou_id={} for ou={} in ou_id_map".format(
                                    ou_id, ou))

                            response = org.list_accounts_for_parent(ou_id)
                            self.logger.debug(
                                "List Accounts for Parent Response")
                            self.logger.debug(response)
                            for account in response.get('Accounts'):
                                params = _params.copy()
                                for key, value in params.items():
                                    if value.lower() == 'accountemail':
                                        params.update(
                                            {key: account.get('Email')})
                                    elif value.lower() == 'accountname':
                                        params.update(
                                            {key: account.get('Name')})
                                    elif value.lower() == 'orgunitname':
                                        params.update({key: ou})

                                logger.info(
                                    "Input parameters format for Account: {} are {}"
                                    .format(account.get('Name'), params))

                                list_of_accounts.append(params)

                        if len(list_of_accounts) > 0:
                            sm_input = self._create_launch_avm_state_machine_input_map(
                                portfolio.name, product.name, list_of_accounts)
                            logger.info(
                                "Launch AVM state machine Input: {}".format(
                                    sm_input))
                            exec_name = "%s-%s-%s" % (
                                sm_input.get('RequestType'), "Launch-AVM",
                                time.strftime("%Y-%m-%dT%H-%M-%S"))
                            sm_exec_arn = self.state_machine.trigger_state_machine(
                                sm_arn_launch_avm, sm_input, exec_name)
                            list_sm_exec_arns.append(sm_exec_arn)

                    time.sleep(int(wait_time))  # brief pause between executions
            self._save_sm_exec_arn(list_sm_exec_arns)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def trigger_state_machines(self):
        try:
            self.manifest = Manifest(self.manifest_file_path)

            if self.pipeline_stage == 'core_accounts':
                self.start_core_account_sm(self.sm_arns_map.get('account'))
            elif self.pipeline_stage == 'core_resources':
                self.start_core_resource_sm(self.sm_arns_map.get('stack_set'))
            elif self.pipeline_stage == 'service_control_policy':
                self.start_service_control_policy_sm(
                    self.sm_arns_map.get('service_control_policy'))
            elif self.pipeline_stage == 'service_catalog':
                self.start_service_catalog_sm(
                    self.sm_arns_map.get('service_catalog'))
            elif self.pipeline_stage == 'baseline_resources':
                self.start_baseline_resources_sm(
                    self.sm_arns_map.get('stack_set'))
            elif self.pipeline_stage == 'launch_avm':
                self.start_launch_avm(self.sm_arns_map.get('launch_avm'))

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def get_state_machines_execution_status(self):
        try:
            sm_exec_arns = self.ssm.get_parameter(self.token)

            if sm_exec_arns == 'PASS':
                self.ssm.delete_parameter(self.token)
                return 'SUCCEEDED', ''
            else:
                # Create a list from the comma-separated string, e.g. ['a', 'b', 'c']
                list_sm_exec_arns = sm_exec_arns.split(",")

                for sm_exec_arn in list_sm_exec_arns:
                    status = self.state_machine.check_state_machine_status(
                        sm_exec_arn)
                    if status == 'RUNNING':
                        return 'RUNNING', ''
                    elif status == 'SUCCEEDED':
                        continue
                    else:
                        self.ssm.delete_parameter(self.token)
                        self.ssm.delete_parameters_by_path(self.token)
                        err_msg = "State Machine Execution Failed, please check the Step Functions console for State Machine Execution ARN: {}".format(
                            sm_exec_arn)
                        return 'FAILED', err_msg

                if self.isSequential:
                    _params_list = self.ssm.get_parameters_by_path(self.token)
                    if _params_list:
                        params_list = sorted(_params_list,
                                             key=lambda i: i['Name'])
                        sm_input = json.loads(params_list[0].get('Value'))
                        if self.pipeline_stage == 'core_accounts':
                            sm_arn = self.sm_arns_map.get('account')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'OUName') + "-" + sm_input.get(
                                    'ResourceProperties').get('AccountName')

                            account_name = sm_input.get(
                                'ResourceProperties').get('AccountName')
                            if account_name.lower() == 'primary':
                                org = Organizations(self.logger)
                                response = org.describe_account(
                                    self.primary_account_id)
                                account_email = response.get('Account').get(
                                    'Email', '')
                                sm_input.get('ResourceProperties').update(
                                    {'AccountEmail': account_email})

                        elif self.pipeline_stage == 'core_resources':
                            sm_arn = self.sm_arns_map.get('stack_set')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'StackSetName')
                            sm_input = self._populate_ssm_params(sm_input)
                        elif self.pipeline_stage == 'service_control_policy':
                            sm_arn = self.sm_arns_map.get(
                                'service_control_policy')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'PolicyDocument').get('Name')
                        elif self.pipeline_stage == 'service_catalog':
                            sm_arn = self.sm_arns_map.get('service_catalog')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'SCProduct').get('ProductName')
                        elif self.pipeline_stage == 'baseline_resources':
                            sm_arn = self.sm_arns_map.get('stack_set')
                            sm_name = sm_input.get('ResourceProperties').get(
                                'StackSetName')
                            sm_input = self._populate_ssm_params(sm_input)

                        exec_name = "%s-%s-%s" % (sm_input.get(
                            'RequestType'), sm_name.replace(
                                " ", ""), time.strftime("%Y-%m-%dT%H-%M-%S"))
                        sm_exec_arn = self.state_machine.trigger_state_machine(
                            sm_arn, sm_input, exec_name)
                        self._save_sm_exec_arn([sm_exec_arn])
                        self.ssm.delete_parameter(params_list[0].get('Name'))
                        return 'RUNNING', ''

                self.ssm.delete_parameter(self.token)
                return 'SUCCEEDED', ''

        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise
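
The class above leans on two helpers, transform_params and reverse_transform_params, to convert between a plain {name: value} dict and CloudFormation's parameter-list form. A minimal sketch consistent with how they are called here (the real implementations may differ):

def transform_params(params):
    # {name: value} -> [{'ParameterKey': name, 'ParameterValue': value}, ...]
    return [{'ParameterKey': key, 'ParameterValue': value}
            for key, value in params.items()]

def reverse_transform_params(cfn_params):
    # [{'ParameterKey': name, 'ParameterValue': value}, ...] -> {name: value}
    return {p['ParameterKey']: p['ParameterValue'] for p in cfn_params}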
Example #16
import logging
from dateutil import parser
from flask import Flask, abort, jsonify, request, send_file

from lib.fan import Fan
from lib.humidify import Humidify
from lib.metrics import Metrics
from lib.growing import Growing
from flask_cors import CORS

from lib.properties import Property

from settings import PHOTOS_DIR

m = Metrics()
g = Growing()
p = Property()
app = Flask(__name__)
cors = CORS(app)
fan = Fan()
h = Humidify()


@app.errorhandler(404)
def resource_not_found(e):
    return jsonify(error=str(e)), 404


@app.route('/api/metrics/<metric>')
def metrics(metric):
Example #17
class LaunchSCP(object):
    def __init__(self, logger, wait_time, manifest_file_path, sm_arn_scp,
                 staging_bucket):
        self.state_machine = StateMachine(logger)
        self.s3 = S3(logger)
        self.send = Metrics(logger)
        self.param_handler = ParamsHandler(logger)
        self.logger = logger
        self.manifest_file_path = manifest_file_path
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
        self.wait_time = wait_time
        self.sm_arn_scp = sm_arn_scp
        self.manifest = None
        self.list_sm_exec_arns = []
        self.nested_ou_delimiter = ""
        self.staging_bucket = staging_bucket
        self.root_id = None

    def _create_service_control_policy_state_machine_input_map(
            self, policy_name, policy_full_path, policy_desc='', ou_list=[]):
        input_params = {}
        policy_doc = {}
        policy_doc.update({'Name': sanitize(policy_name)})
        policy_doc.update({'Description': policy_desc})
        policy_doc.update({'PolicyURL': policy_full_path})
        input_params.update({'PolicyDocument': policy_doc})
        input_params.update({'AccountId': ''})
        input_params.update({'PolicyList': []})
        input_params.update({'Operation': ''})
        input_params.update({'OUList': ou_list})
        input_params.update({'OUNameDelimiter': self.nested_ou_delimiter})
        return self._create_state_machine_input_map(input_params)

    def _create_state_machine_input_map(self,
                                        input_params,
                                        request_type='Create'):
        request = {}
        request.update({'RequestType': request_type})
        request.update({'ResourceProperties': input_params})

        return request

    def _stage_template(self, relative_template_path):
        if relative_template_path.lower().startswith('s3'):
            # Convert the S3 URL s3://bucket-name/object
            # to HTTP URL https://s3.amazonaws.com/bucket-name/object
            s3_url = convert_s3_url_to_http_url(relative_template_path)
        else:
            local_file = os.path.join(self.manifest_folder,
                                      relative_template_path)
            # remote_file = "{}/{}_{}".format(TEMPLATE_KEY_PREFIX, self.token, relative_template_path[relative_template_path.rfind('/')+1:])
            remote_file = "{}/{}".format(TEMPLATE_KEY_PREFIX,
                                         relative_template_path)
            logger.info(
                "Uploading the template file: {} to S3 bucket: {} and key: {}".
                format(local_file, self.staging_bucket, remote_file))
            self.s3.upload_file(self.staging_bucket, local_file, remote_file)
            s3_url = "{}{}{}{}".format('https://s3.amazonaws.com/',
                                       self.staging_bucket, '/', remote_file)
        return s3_url

    def _run_or_queue_state_machine(self, sm_input, sm_arn, sm_name):
        logger.info("State machine Input: {}".format(sm_input))
        exec_name = "%s-%s-%s" % (sm_input.get('RequestType'),
                                  trim_length(sm_name.replace(" ", ""), 50),
                                  time.strftime("%Y-%m-%dT%H-%M-%S"))

        # Execute all state machines at a regular interval of wait_time
        sm_exec_arn = self.state_machine.trigger_state_machine(
            sm_arn, sm_input, exec_name)
        time.sleep(int(wait_time))  # brief pause between executions
        self.list_sm_exec_arns.append(sm_exec_arn)

    def trigger_service_control_policy_state_machine(self):
        try:
            self.manifest = Manifest(self.manifest_file_path)
            self.start_service_control_policy_sm()
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def monitor_state_machines_execution_status(self):
        try:
            if self.list_sm_exec_arns:
                final_status = 'RUNNING'

                while final_status == 'RUNNING':
                    for sm_exec_arn in self.list_sm_exec_arns:
                        status = self.state_machine.check_state_machine_status(
                            sm_exec_arn)
                        if status == 'RUNNING':
                            final_status = 'RUNNING'
                            time.sleep(int(wait_time))
                            break
                        else:
                            final_status = 'COMPLETED'

                err_flag = False
                failed_sm_execution_list = []
                for sm_exec_arn in self.list_sm_exec_arns:
                    status = self.state_machine.check_state_machine_status(
                        sm_exec_arn)
                    if status == 'SUCCEEDED':
                        continue
                    else:
                        failed_sm_execution_list.append(sm_exec_arn)
                        err_flag = True
                        continue

                if err_flag:
                    return 'FAILED', failed_sm_execution_list
                else:
                    return 'SUCCEEDED', ''
            else:
                self.logger.info(
                    "SM Execution List {} is empty, nothing to monitor.".
                    format(self.list_sm_exec_arns))
                return None, []
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise

    def start_service_control_policy_sm(self):
        try:
            logger.info("Processing SCPs from {} file".format(
                self.manifest_file_path))
            count = 0

            for policy in self.manifest.organization_policies:
                # Generate the list of OUs to attach this SCP to
                ou_list = []
                attach_ou_list = set(policy.apply_to_accounts_in_ou)

                for ou in attach_ou_list:
                    ou_list.append((ou, 'Attach'))

                policy_full_path = self._stage_template(policy.policy_file)
                sm_input = self._create_service_control_policy_state_machine_input_map(
                    policy.name, policy_full_path, policy.description, ou_list)
                self._run_or_queue_state_machine(sm_input, self.sm_arn_scp,
                                                 policy.name)

                # Count number of SCPs
                count += 1

            data = {"SCPPolicyCount": str(count)}
            self.send.metrics(data)

            # Exit when there are no organization policies
            if count == 0:
                logger.info("No organization policies are found.")
                sys.exit(0)
            return
        except Exception as e:
            message = {
                'FILE': __file__.split('/')[-1],
                'METHOD': inspect.stack()[0][3],
                'EXCEPTION': str(e)
            }
            self.logger.exception(message)
            raise
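
# For illustration only (field names inferred from the attributes read above;
# the actual manifest schema may differ): an organization_policies entry in the
# manifest is expected to look roughly like this YAML:
#
#   organization_policies:
#     - name: example-scp
#       description: Example service control policy
#       policy_file: policies/example-scp.json
#       apply_to_accounts_in_ou:
#         - ExampleOU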
Example #18
def notify(finding, message, logger, cwlogs=False, sechub=True, sns=False):
    """
    Consolidates several outputs to a single call.

    Attributes
    ----------
    finding: finding object for which notification is to be done
    message: dict of notification data:
        {
            'Account': string,
            'AffectedObject': string,
            'Remediation': string,
            'State': string,
            'Note': string
        }
    logger: logger object for logging to stdout
    cwlogs: CloudWatch Logs writer object, or False - log to application log group?
    sechub: boolean - update Security Hub notes on the finding?
    sns: SNS publisher object, or False - send to sns topic?
    """

    remediation_adj = ''
    if 'State' in message:
        if message['State'] == 'RESOLVED':
            remediation_adj = 'remediation was successful'
        elif message['State'] == 'INITIAL':
            remediation_adj = 'remediation started'
        elif message['State'] == 'FAILED':
            remediation_adj = 'remediation failed. Please remediate manually'
        if 'Note' not in message or not message['Note']:
            message['Note'] = '"' + message.get('Remediation', 'error missing remediation') +\
            '" ' + remediation_adj
    else:
        message['State'] = 'INFO'

    if 'Note' not in message or not message['Note']:
        message['Note'] = 'error - missing note'

    # send metrics
    try:
        metrics_data = message['metrics_data']
        metrics = Metrics({'detail-type': 'None'})
        metrics_data['status'] = message['State']
        metrics.send_metrics(metrics_data)
    except Exception as e:
        logger.error(e)
        logger.error('Failed to send metrics')

    # lambda logs - always
    logger.info(
        message.get('State', 'INFO') + ': ' + message.get('Note') +\
        ', Account Id: ' + message.get('Account', 'error') + \
        ', Resource: ' + message.get('AffectedObject', 'error')
    )

    # log to application log
    if cwlogs:
        # to take advantage of buffering, the caller controls the
        # connection.
        cwlogs.add_message(
            message.get('State') + ': ' + message.get('Note') +\
            ', Account Id: ' + message.get('Account', 'error') + \
            ', Resource: ' + message.get('AffectedObject', 'error')
        )

    if sechub:
        if message.get('State') == 'RESOLVED':
            finding.resolve(message.get('State') + ': ' + message.get('Note'))
        elif message.get('State') == 'INITIAL':
            finding.flag(message.get('State') + ': ' + message.get('Note'))
        else:
            finding.update_text(
                message.get('State', 'INFO') + ': ' + message.get('Note'))

    if sns:
        try:
            sns.postit('SO0111-SHARR_Topic', message, AWS_REGION)
        except Exception as e:
            logger.error(e)
            logger.error('Unable to send to sns')
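
# A minimal usage sketch (not from the original source). The `finding` object is
# assumed to expose resolve()/flag()/update_text() as used above, and the message
# values below are hypothetical:
#
# example_message = {
#     'Account': '111122223333',
#     'AffectedObject': 'IAM user: example-user',
#     'Remediation': 'Deactivate unused credentials',
#     'State': 'RESOLVED',
#     'Note': '',  # left empty so notify() builds it from Remediation + State
#     'metrics_data': {'status': None}
# }
# notify(finding, example_message, logger, cwlogs=False, sechub=True, sns=False)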
Example #19
class DeployStackSetStateMachine(object):
    def __init__(self, logger, wait_time, manifest_file_path, sm_arn_stackset, staging_bucket, execution_mode):
        self.state_machine = StateMachine(logger)
        self.ssm = SSM(logger)
        self.s3 = S3(logger)
        self.send = Metrics(logger)
        self.param_handler = ParamsHandler(logger)
        self.logger = logger
        self.manifest_file_path = manifest_file_path
        self.manifest_folder = manifest_file_path[:-len(MANIFEST_FILE_NAME)]
        self.wait_time = wait_time
        self.sm_arn_stackset = sm_arn_stackset
        self.manifest = None
        self.list_sm_exec_arns = []
        self.staging_bucket = staging_bucket
        self.root_id = None
        self.uuid = uuid4()
        self.state_machine_event = {}
        if execution_mode.lower() == 'sequential':
            self.logger.info("Running {} mode".format(execution_mode))
            self.sequential_flag = True
        else:
            self.logger.info("Running {} mode".format(execution_mode))
            self.sequential_flag = False

    def _stage_template(self, relative_template_path):
        try:
            if relative_template_path.lower().startswith('s3'):
                # Convert the S3 URL s3://bucket-name/object to HTTP URL https://s3.amazonaws.com/bucket-name/object
                s3_url = convert_s3_url_to_http_url(relative_template_path)
            else:
                local_file = os.path.join(self.manifest_folder, relative_template_path)
                remote_file = "{}/{}_{}".format(TEMPLATE_KEY_PREFIX, self.uuid, relative_template_path)
                logger.info("Uploading the template file: {} to S3 bucket: {} and key: {}".format(local_file,
                                                                                                  self.staging_bucket,
                                                                                                  remote_file))
                self.s3.upload_file(self.staging_bucket, local_file, remote_file)
                s3_url = "{}{}{}{}".format('https://s3.amazonaws.com/', self.staging_bucket, '/', remote_file)
            return s3_url
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise
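
    # A hedged sketch (an assumption; the helper is imported from elsewhere in
    # the solution) of what convert_s3_url_to_http_url is expected to do:
    #   's3://bucket-name/path/to/object'
    #     -> 'https://s3.amazonaws.com/bucket-name/path/to/object'
    # i.e. roughly: 'https://s3.amazonaws.com/' + s3_url[len('s3://'):]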

    def _load_params(self, relative_parameter_path, account=None, region=None):
        try:
            if relative_parameter_path.lower().startswith('s3'):
                parameter_file = download_remote_file(self.logger, relative_parameter_path)
            else:
                parameter_file = os.path.join(self.manifest_folder, relative_parameter_path)

            logger.info("Parsing the parameter file: {}".format(parameter_file))

            with open(parameter_file, 'r') as content_file:
                parameter_file_content = content_file.read()

            params = json.loads(parameter_file_content)
            if account is not None:
                # Deploying Core resource Stack Set
                # The last parameter is set to False, because we do not want to replace the SSM parameter values yet.
                sm_params = self.param_handler.update_params(params, account, region, False)
            else:
                # Deploying Baseline resource Stack Set
                sm_params = self.param_handler.update_params(params)

            logger.info("Input Parameters for State Machine: {}".format(sm_params))
            return sm_params
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _create_ssm_input_map(self, ssm_parameters):
        try:
            ssm_input_map = {}

            for ssm_parameter in ssm_parameters:
                key = ssm_parameter.name
                value = ssm_parameter.value
                ssm_value = self.param_handler.update_params(transform_params({key: value}))
                ssm_input_map.update(ssm_value)

            return ssm_input_map

        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _create_state_machine_input_map(self, input_params, request_type='Create'):
        try:
            self.state_machine_event.update({'RequestType': request_type})
            self.state_machine_event.update({'ResourceProperties': input_params})

        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _create_stack_set_state_machine_input_map(self, stack_set_name, template_url, parameters,
                                                  account_list, regions_list, ssm_map):
        input_params = {}
        input_params.update({'StackSetName': sanitize(stack_set_name)})
        input_params.update({'TemplateURL': template_url})
        input_params.update({'Parameters': parameters})
        input_params.update({'Capabilities': CAPABILITIES})

        if len(account_list) > 0:
            input_params.update({'AccountList': account_list})
            if len(regions_list) > 0:
                input_params.update({'RegionList': regions_list})
            else:
                input_params.update({'RegionList': [self.manifest.region]})
        else:
            input_params.update({'AccountList': ''})
            input_params.update({'RegionList': ''})

        if ssm_map is not None:
            input_params.update({'SSMParameters': ssm_map})

        self._create_state_machine_input_map(input_params)
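
        # For illustration only (values are made up): after the two methods above
        # run, self.state_machine_event looks roughly like:
        # {
        #     'RequestType': 'Create',
        #     'ResourceProperties': {
        #         'StackSetName': 'CustomControlTower-ExampleResource',
        #         'TemplateURL': 'https://s3.amazonaws.com/<staging-bucket>/<key>',
        #         'Parameters': {'ExampleKey': 'ExampleValue'},
        #         'Capabilities': CAPABILITIES,
        #         'AccountList': ['111122223333'],
        #         'RegionList': ['us-east-1'],
        #         'SSMParameters': {'/example/ssm/key': 'value'}
        #     }
        # }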

    def _populate_ssm_params(self):
        try:
            # Scenario: one core resource exports a CFN stack output to an SSM
            # parameter, and the next core resource reads that SSM parameter as
            # input. The second resource must therefore wait for the first one to
            # finish, then read the SSM parameter and use its value in its own
            # state machine input.
            # Get the parameters for CFN template from self.state_machine_event
            logger.debug("Populating SSM parameter values for SM input: {}".format(self.state_machine_event))
            params = self.state_machine_event.get('ResourceProperties').get('Parameters', {})
            # First transform it from {name: value} to
            # [{'ParameterKey': name, 'ParameterValue': value}]
            # then replace the SSM parameter names with their values
            sm_params = self.param_handler.update_params(transform_params(params))
            # Put it back into the self.state_machine_event
            self.state_machine_event.get('ResourceProperties').update({'Parameters': sm_params})
            logger.debug("Done populating SSM parameter values for SM input: {}".format(self.state_machine_event))

        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise
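
        # Illustration (assumed shapes, based on the comments above): transform_params
        # reshapes {name: value} into CloudFormation's list form, e.g.
        #   {'VpcId': '<ssm-lookup-token>'}
        # becomes
        #   [{'ParameterKey': 'VpcId', 'ParameterValue': '<ssm-lookup-token>'}]
        # and update_params() then resolves any SSM lookup tokens to stored values.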

    def _compare_template_and_params(self):
        try:
            stack_name = self.state_machine_event.get('ResourceProperties').get('StackSetName', '')
            flag = False
            if stack_name:
                stack_set = StackSet(self.logger)
                describe_response = stack_set.describe_stack_set(stack_name)
                if describe_response is not None:
                    self.logger.info("Found existing stack set.")

                    self.logger.info("Checking the status of last stack set operation on {}".format(stack_name))
                    response = stack_set.list_stack_set_operations(StackSetName=stack_name,
                                                                   MaxResults=1)

                    if response:
                        if response.get('Summaries'):
                            for instance in response.get('Summaries'):
                                self.logger.info("Status of last stack set operation : {}"
                                                 .format(instance.get('Status')))
                                if instance.get('Status') != 'SUCCEEDED':
                                    self.logger.info("The last stack operation did not succeed. "
                                                     "Triggering Update StackSet for {}".format(stack_name))
                                    return False

                    self.logger.info("Comparing the template of the StackSet: {} with local copy of template"
                                     .format(stack_name))

                    template_http_url = self.state_machine_event.get('ResourceProperties').get('TemplateURL', '')
                    if template_http_url:
                        template_s3_url = convert_http_url_to_s3_url(template_http_url)
                        local_template_file = download_remote_file(self.logger, template_s3_url)
                    else:
                        self.logger.error("TemplateURL in state machine input is empty. Check state_machine_event:{}"
                                          .format(self.state_machine_event))
                        return False

                    cfn_template_file = tempfile.mkstemp()[1]
                    with open(cfn_template_file, "w") as f:
                        f.write(describe_response.get('StackSet').get('TemplateBody'))

                    template_compare = filecmp.cmp(local_template_file, cfn_template_file, False)
                    self.logger.info("Comparing the parameters of the StackSet: {} "
                                     "with local copy of JSON parameters file".format(stack_name))

                    params_compare = True
                    params = self.state_machine_event.get('ResourceProperties').get('Parameters', {})
                    if template_compare:
                        cfn_params = reverse_transform_params(describe_response.get('StackSet').get('Parameters'))
                        for key, value in params.items():
                            if cfn_params.get(key, '') == value:
                                pass
                            else:
                                params_compare = False
                                break

                    self.logger.info("template_compare={}".format(template_compare))
                    self.logger.info("params_compare={}".format(params_compare))
                    if template_compare and params_compare:
                        account_list = self.state_machine_event.get('ResourceProperties').get("AccountList", [])
                        if account_list:
                            self.logger.info("Comparing the Stack Instances Account & Regions for StackSet: {}"
                                             .format(stack_name))
                            expected_region_list = set(self.state_machine_event.get('ResourceProperties').get("RegionList", []))

                            # iterate over the accounts in the event account list
                            for account in account_list:
                                actual_region_list = set()

                                self.logger.info("### Listing the Stack Instances for StackSet: {} and Account: {} ###"
                                                 .format(stack_name, account))
                                stack_instance_list = stack_set.list_stack_instances_per_account(stack_name, account)

                                self.logger.info(stack_instance_list)

                                if stack_instance_list:
                                    for instance in stack_instance_list:
                                        if instance.get('Status').upper() == 'CURRENT':
                                            actual_region_list.add(instance.get('Region'))
                                        else:
                                            self.logger.info("Found at least one of the Stack Instances in {} state."
                                                             " Triggering Update StackSet for {}"
                                                             .format(instance.get('Status'),
                                                                     stack_name))
                                            return False
                                else:
                                    self.logger.info("Found no stack instances in account: {}, "
                                                     "Updating StackSet: {}".format(account, stack_name))
                                    return False

                                if expected_region_list.issubset(actual_region_list):
                                    self.logger.info("Found expected regions : {} in deployed stack instances : {},"
                                                     " so skipping Update StackSet for {}"
                                                     .format(expected_region_list,
                                                             actual_region_list,
                                                             stack_name))
                                    flag = True
                        else:
                            self.logger.info("Found no changes in template & parameters, "
                                             "so skipping Update StackSet for {}".format(stack_name))
                            flag = True
            return flag
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def state_machine_failed(self, status, failed_execution_list):
        error = " StackSet State Machine Execution(s) Failed. Navigate to the AWS Step Functions console and" \
                " review the following State Machine Executions. ARN List: {}".format(failed_execution_list)
        if status == 'FAILED':
            logger.error(100 * '*')
            logger.error(error)
            logger.error(100 * '*')
            sys.exit(1)

    def _run_or_queue_state_machine(self, stackset_name):
        try:
            logger.info("State machine Input: {}".format(self.state_machine_event))
            exec_name = "%s-%s-%s" % (self.state_machine_event.get('RequestType'), trim_length(stackset_name.replace(" ", ""), 50),
                                      time.strftime("%Y-%m-%dT%H-%M-%S"))
            # If Sequential, wait for the SM to finish executing before kicking off the next one
            if self.sequential_flag:
                self.logger.info(" > > > > > >  Running Sequential Mode. > > > > > >")
                self._populate_ssm_params()
                if self._compare_template_and_params():
                    return
                else:
                    sm_exec_arn = self.state_machine.trigger_state_machine(self.sm_arn_stackset, self.state_machine_event, exec_name)
                    self.list_sm_exec_arns.append(sm_exec_arn)
                    status, failed_execution_list = self.monitor_state_machines_execution_status()
                    if status == 'FAILED':
                        self.state_machine_failed(status, failed_execution_list)
                    else:
                        self.logger.info("State Machine execution completed. Starting next execution...")
            # Else if Parallel, execute all SMs at a regular interval of wait_time
            else:
                self.logger.info(" | | | | | |  Running Parallel Mode. | | | | | |")
                self._populate_ssm_params()
                # if the stack set comparison matches, skip the SM execution
                if self._compare_template_and_params():
                    return
                else:  # otherwise, trigger the SM
                    sm_exec_arn = self.state_machine.trigger_state_machine(self.sm_arn_stackset, self.state_machine_event, exec_name)
                time.sleep(int(wait_time))  # sleep for a while between executions
                self.list_sm_exec_arns.append(sm_exec_arn)
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _deploy_resource(self, resource, account_list):
        try:
            template_full_path = self._stage_template(resource.template_file)
            params = {}
            if resource.parameter_file:
                if len(resource.regions) > 0:
                    params = self._load_params(resource.parameter_file, account_list, resource.regions[0])
                else:
                    params = self._load_params(resource.parameter_file, account_list, self.manifest.region)

            ssm_map = self._create_ssm_input_map(resource.ssm_parameters)

            # Deploying Core resource Stack Set
            stack_name = "CustomControlTower-{}".format(resource.name)
            self._create_stack_set_state_machine_input_map(stack_name, template_full_path,
                                                           params, account_list, resource.regions, ssm_map)

            self.logger.info(" >>> State Machine Input >>>")
            self.logger.info(self.state_machine_event)

            self._run_or_queue_state_machine(stack_name)
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def _get_root_id(self, org):
        response = org.list_roots()
        self.logger.info("Response: List Roots")
        self.logger.info(response)
        return response['Roots'][0].get('Id')
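
    # Note: Organizations list_roots returns {'Roots': [{'Id': 'r-....', ...}]};
    # an AWS Organization has a single root, hence index [0].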

    def _list_ou_for_parent(self, org, parent_id):
        _ou_list = org.list_organizational_units_for_parent(parent_id)
        self.logger.info("Print Organizational Units List under {}".format(parent_id))
        self.logger.info(_ou_list)
        return _ou_list

    def _get_accounts_in_ou(self, org, ou_id_list):
        _accounts_in_ou = []
        accounts_in_all_ous = []
        ou_id_to_account_map = {}

        for _ou_id in ou_id_list:
            _account_list = org.list_accounts_for_parent(_ou_id)
            for _account in _account_list:
                # keep only ACTIVE accounts
                if _account.get('Status') == "ACTIVE":
                    # create a list of accounts in OU
                    accounts_in_all_ous.append(_account.get('Id'))
                    _accounts_in_ou.append(_account.get('Id'))

            # create a map of accounts for each ou
            self.logger.info("Creating Key:Value Mapping - OU ID: {} ; Account List: {}"
                             .format(_ou_id, _accounts_in_ou))
            ou_id_to_account_map.update({_ou_id: _accounts_in_ou})
            self.logger.info(ou_id_to_account_map)

            # reset list of accounts in the OU
            _accounts_in_ou = []

        self.logger.info("All accounts in OU List: {}".format(accounts_in_all_ous))
        self.logger.info("OU to Account ID mapping")
        self.logger.info(ou_id_to_account_map)
        return accounts_in_all_ous, ou_id_to_account_map

    def _get_ou_ids(self, org):
        # for each OU get list of account
        # get root id
        root_id = self._get_root_id(org)

        # get OUs under the Org root
        ou_list_at_root_level = self._list_ou_for_parent(org, root_id)

        ou_id_list = []
        _ou_name_to_id_map = {}
        _all_ou_ids = []

        for ou_at_root_level in ou_list_at_root_level:
            # build list of all the OU IDs under Org root
            _all_ou_ids.append(ou_at_root_level.get('Id'))
            # build a list of ou id
            _ou_name_to_id_map.update({ou_at_root_level.get('Name'): ou_at_root_level.get('Id')})

        self.logger.info("Print OU Name to OU ID Map")
        self.logger.info(_ou_name_to_id_map)

        # return:
        # 1. list of all OU IDs under the Org root
        # 2. map of OU Name to OU ID
        return _all_ou_ids, _ou_name_to_id_map

    def get_account_for_name(self, org):
        # get all accounts in the organization
        account_list = org.get_accounts_in_org()
        #self.logger.info("Print Account List: {}".format(account_list))

        _name_to_account_map = {}
        for account in account_list:
            if account.get("Status") == "ACTIVE":
                _name_to_account_map.update({account.get("Name"): account.get("Id")})

        self.logger.info("Print Account Name > Account Mapping")
        self.logger.info(_name_to_account_map)

        return _name_to_account_map

    def get_organization_details(self):
        # > build dict
        # KEY: OU Name (in the manifest)
        # VALUE: OU ID (at root level)
        # > build list
        # all OU IDs under root
        org = Organizations(self.logger)
        all_ou_ids, ou_name_to_id_map = self._get_ou_ids(org)
        # > build list of all active accounts
        # use case: use to validate accounts in the manifest file.
        # > build dict
        # KEY: OU ID (for each OU at root level)
        # VALUE: get list of all active accounts
        # use case: map OU Name to account IDs
        accounts_in_all_ous, ou_id_to_account_map = self._get_accounts_in_ou(org, all_ou_ids)
        # build dict
        # KEY: account name
        # VALUE: account id
        # use case: convert account names in the manifest to account IDs for the SM event
        name_to_account_map = self.get_account_for_name(org)
        return accounts_in_all_ous, ou_id_to_account_map, ou_name_to_id_map, name_to_account_map
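
        # For illustration (hypothetical IDs), the four return values have shapes like:
        #   accounts_in_all_ous:  ['111122223333', '444455556666']
        #   ou_id_to_account_map: {'ou-root-aaaaaaaa': ['111122223333']}
        #   ou_name_to_id_map:    {'Security': 'ou-root-aaaaaaaa'}
        #   name_to_account_map:  {'log-archive': '444455556666'}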

    def start_stackset_sm(self):
        try:
            logger.info("Parsing Core Resources from {} file".format(self.manifest_file_path))
            count = 0

            accounts_in_all_ous, ou_id_to_account_map, ou_name_to_id_map, name_to_account_map = self.get_organization_details()

            for resource in self.manifest.cloudformation_resources:
                self.logger.info(">>>>>>>>> START : {} >>>>>>>>>".format(resource.name))
                # Handle scenario if 'deploy_to_ou' key does not exist in the resource
                try:
                    self.logger.info(resource.deploy_to_ou)
                except AttributeError:
                    resource.deploy_to_ou = []

                # Handle scenario if 'deploy_to_account' key does not exist in the resource
                try:
                    self.logger.info(resource.deploy_to_account)
                except AttributeError:
                    resource.deploy_to_account = []

                # find accounts for given ou name
                accounts_in_ou = []
                ou_ids_manifest = []

                # proceed if the OU name list is not empty
                if resource.deploy_to_ou:
                    # convert OU Name to OU IDs
                    for ou_name in resource.deploy_to_ou:
                        ou_id = [value for key, value in ou_name_to_id_map.items() if ou_name.lower() in key.lower()]
                        ou_ids_manifest.extend(ou_id)

                    # convert OU IDs to accounts
                    for ou_id, accounts in ou_id_to_account_map.items():
                        if ou_id in ou_ids_manifest:
                            accounts_in_ou.extend(accounts)

                    self.logger.info(">>> Accounts: {} in OUs: {}".format(accounts_in_ou, resource.deploy_to_ou))

                # convert account numbers to string type
                account_list = self._convert_list_values_to_string(resource.deploy_to_account)
                self.logger.info(">>>>>> ACCOUNT LIST")
                self.logger.info(account_list)

                # separate account id and emails
                name_list = []
                new_account_list = []
                self.logger.info(account_list)
                for item in account_list:
                    if item.isdigit() and len(item) == 12:  # if an actual account ID
                        new_account_list.append(item)
                        self.logger.info(new_account_list)
                    else:
                        name_list.append(item)
                        self.logger.info(name_list)

                # proceed if the name list is not empty
                if name_list:
                    # convert account names to account IDs
                    for name in name_list:
                        name_account = [value for key, value in name_to_account_map.items() if
                                         name.lower() in key.lower()]
                        self.logger.info("%%%%%%% Name {} -  Account {}".format(name, name_account))
                        new_account_list.extend(name_account)

                # Remove account IDs from the manifest that are not in the organization or not active
                sanitized_account_list = list(set(new_account_list).intersection(set(accounts_in_all_ous)))
                self.logger.info("Print Updated Manifest Account List")
                self.logger.info(sanitized_account_list)

                # merge the manifest account list with the accounts under the OUs in the manifest
                sanitized_account_list.extend(accounts_in_ou)
                sanitized_account_list = list(set(sanitized_account_list)) # remove duplicate accounts
                self.logger.info("Print merged account list - accounts in manifest + account under OU in manifest")
                self.logger.info(sanitized_account_list)

                if resource.deploy_method.lower() == 'stack_set':
                    self._deploy_resource(resource, sanitized_account_list)
                else:
                    raise Exception("Unsupported deploy_method: {} found for resource {} and Account: {} in Manifest"
                                    .format(resource.deploy_method, resource.name, sanitized_account_list))
                self.logger.info("<<<<<<<<< FINISH : {} <<<<<<<<<".format(resource.name))

                # Count number of stack sets deployed
                count += 1
            data = {"StackSetCount": str(count)}
            self.send.metrics(data)

            return
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    # return list of strings
    def _convert_list_values_to_string(self, _list):
        return list(map(str, _list))
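
    # e.g. [111122223333, 444455556666] -> ['111122223333', '444455556666']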

    # monitor list of state machine executions
    def monitor_state_machines_execution_status(self):
        try:
            if self.list_sm_exec_arns:
                self.logger.info("Starting to monitor the SM Executions: {}".format(self.list_sm_exec_arns))
                final_status = 'RUNNING'

                while final_status == 'RUNNING':
                    for sm_exec_arn in self.list_sm_exec_arns:
                        status = self.state_machine.check_state_machine_status(sm_exec_arn)
                        if status == 'RUNNING':
                            final_status = 'RUNNING'
                            time.sleep(int(wait_time))
                            break
                        else:
                            final_status = 'COMPLETED'

                err_flag = False
                failed_sm_execution_list = []
                for sm_exec_arn in self.list_sm_exec_arns:
                    status = self.state_machine.check_state_machine_status(sm_exec_arn)
                    if status == 'SUCCEEDED':
                        continue
                    else:
                        failed_sm_execution_list.append(sm_exec_arn)
                        err_flag = True
                        continue

                if err_flag:
                    return 'FAILED', failed_sm_execution_list
                else:
                    return 'SUCCEEDED', ''
            else:
                self.logger.info("SM Execution List {} is empty, nothing to monitor.".format(self.list_sm_exec_arns))
                return None, []

        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise

    def trigger_stackset_state_machine(self):
        try:
            self.manifest = Manifest(self.manifest_file_path)
            self.start_stackset_sm()
            return
        except Exception as e:
            message = {'FILE': __file__.split('/')[-1], 'METHOD': inspect.stack()[0][3], 'EXCEPTION': str(e)}
            self.logger.exception(message)
            raise
Example #20
    def train(self, evaluate_every=100):
        test_updates = self.test_updates()

        cumulative_train_metrics = Metrics.empty(mode="train")
        cumulative_evaluate_metrics = Metrics.empty(mode="eval")

        for update_info in self.train_and_evaluate_updates(
                evaluate_every=evaluate_every):
            if update_info.from_train:
                cumulative_train_metrics += update_info.metrics

                print(
                    f"{update_info.batch.ix} \t| {update_info.metrics.loss} \t= {update_info.model_loss} \t+ {update_info.fooling_loss} \t| {update_info.discriminator_loss}"
                )

                if update_info.batch.ix % 200 == 0:
                    with torch.no_grad():
                        predicted = update_info.decoded_inferred_texts[
                            0].replace('\n', ' ').strip('❟ ❟ ❟')
                        headline = update_info.batch.orig_headline[0].replace(
                            '\n', ' ').lower().strip()
                        text = update_info.batch.orig_text[0].replace(
                            '\n', ' ').lower().strip()
                        print(
                            f"{update_info.batch.ix}\n\nTEXT:\n{text} \n\nHEADLINE:\n{headline} \n\nPREDICTED SUMMARY:\n{predicted}"
                        )

                if update_info.batch.ix % 10 == 0:
                    self.writer.add_scalar('loss/train',
                                           cumulative_train_metrics.loss,
                                           update_info.batch.ix)

                    # assumes the Metrics accumulator exposes model_loss and
                    # fooling_loss components alongside the total loss
                    self.writer.add_scalar('model-loss/train',
                                           cumulative_train_metrics.model_loss,
                                           update_info.batch.ix)

                    self.writer.add_scalar('fooling-loss/train',
                                           cumulative_train_metrics.fooling_loss,
                                           update_info.batch.ix)

                    cumulative_train_metrics = Metrics.empty(mode="train")

            if update_info.from_evaluate:
                cumulative_evaluate_metrics += update_info.metrics

                if len(cumulative_evaluate_metrics) == 10:
                    with torch.no_grad():
                        predicted = update_info.decoded_inferred_texts[
                            0].replace('\n', ' ').strip('❟ ❟ ❟')
                        headline = update_info.batch.orig_headline[0].replace(
                            '\n', ' ').lower().strip()
                        text = update_info.batch.orig_text[0].replace(
                            '\n', ' ').lower().strip()
                        print(
                            f"{update_info.batch.ix}\n\nEVAL TEXT:\n{text} \n\nEVAL HEADLINE:\n{headline} \n\nEVAL PREDICTED SUMMARY:\n{predicted}"
                        )

                        self.writer.add_text(
                            'text/eval', text,
                            int(update_info.batch.ix / evaluate_every))

                        self.writer.add_text(
                            'headline/eval', headline,
                            int(update_info.batch.ix / evaluate_every))

                        self.writer.add_text(
                            'predicted/eval', predicted,
                            int(update_info.batch.ix / evaluate_every))

                    self.writer.add_scalar(
                        'rouge-1/eval',
                        cumulative_evaluate_metrics.rouge_score,
                        int(update_info.batch.ix / evaluate_every))

                    cumulative_evaluate_metrics = Metrics.empty(mode="eval")

                print(f"Eval: {update_info.metrics.loss}")

            if update_info.batch.ix % 600 == 0 and update_info.batch.ix != 0:
                print(f"Saving checkpoint")
                self.save_checkpoint()
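
# A minimal sketch (an assumption, not the project's actual class) of the Metrics
# accumulator interface the trainer above relies on: Metrics.empty(mode),
# "+=" accumulation, a running .loss, and len() for the number of updates folded in.
class ExampleMetrics:
    def __init__(self, mode):
        self.mode = mode
        self.loss = 0.0   # running total loss
        self.count = 0    # number of accumulated updates

    @classmethod
    def empty(cls, mode):
        # fresh accumulator for a train/eval phase
        return cls(mode)

    def __iadd__(self, other):
        # fold another metrics object into the running totals
        self.loss += other.loss
        self.count += 1
        return self

    def __len__(self):
        return self.count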
Example #21
from lib.metrics import Metrics
from lib.logger import Logger
from decimal import Decimal

log_level = 'info'
logger = Logger(loglevel=log_level)

send = Metrics(logger)


def test_backend_metrics():
    solution_id = 'SO_unit_test'
    data = {'key_string1': '2018-06-15',
            'key_string2': 'A1B2',
            'decimal': Decimal('5')
            }
    url = 'https://oszclq8tyh.execute-api.us-east-1.amazonaws.com/prod/generic'
    response = send.metrics(solution_id, data, url)
    logger.info(response)
    assert response == 200