Exemple #1
0
class OneLoginSAMLAuth(BaseAuthPlugin):
    name = 'OneLoginSAML'
    ns = NAMESPACE
    views = (SamlLoginRequest, SamlLoginConsumer, SamlLogoutRequest, SamlLogoutConsumer)
    options = (
        ConfigOption('strict', True, 'bool', 'Strict validation of SAML responses'),
        ConfigOption('debug', False, 'bool', 'Enable SAML debug mode'),
        ConfigOption('sp_entity_id', None, 'string', 'Service Provider Entity ID'),
        ConfigOption('sp_acs', None, 'string', 'Assertion Consumer endpoint'),
        ConfigOption('sp_sls', None, 'string', 'Single Logout Service endpoint'),
        ConfigOption('idp_entity_id', None, 'string', 'Identity Provider Entity ID'),
        ConfigOption('idp_ssos', None, 'string', 'Single Sign-On Service endpoint'),
        ConfigOption('idp_sls', None, 'string', 'Single Logout Service endpoint'),
        ConfigOption('idp_x509cert', None, 'string', 'Base64 encoded x509 certificate for SAML validation')
    )
    readonly = True
    login = {'url': '/auth/saml/login'}
    logout = '/auth/saml/logout'
Exemple #2
0
class RequiredTagsAuditor(BaseAuditor):
    name = 'Required Tags Compliance'
    ns = NS_AUDITOR_REQUIRED_TAGS
    interval = dbconfig.get('interval', ns, 30)
    tracking_enabled = dbconfig.get('enabled', NS_GOOGLE_ANALYTICS, False)
    tracking_id = dbconfig.get('tracking_id', NS_GOOGLE_ANALYTICS)
    confirm_shutdown = dbconfig.get('confirm_shutdown', ns, True)
    required_tags = []
    collect_only = None
    start_delay = 0
    options = (
        ConfigOption('action_taker_arn', '', 'string',
                     'Lambda entry point for action taker'),
        ConfigOption(
            'alert_settings', {
                '*': {
                    'alert': ['0 seconds', '3 weeks', '27 days'],
                    'stop': '4 weeks',
                    'remove': '12 weeks',
                    'scope': ['*']
                }
            }, 'json', 'Schedule for warning, stop and removal'),
        ConfigOption(
            'audit_scope',
            # max_items is 99 here, but is pulled during runtime and adjusted to the
            #  max number of available resources it doesn't really matter what we put
            {
                'enabled': [],
                'available':
                ['aws_ec2_instance', 'aws_s3_bucket', 'aws_rds_instance'],
                'max_items':
                99,
                'min_items':
                0
            },
            'choice',
            'Select the services you would like to audit'),
        ConfigOption('audit_ignore_tag', 'cinq_ignore', 'string',
                     'Do not audit resources have this tag set'),
        ConfigOption('always_send_email', True, 'bool',
                     'Send emails even in collect mode'),
        ConfigOption('collect_only', True, 'bool',
                     'Do not shutdown instances, only update caches'),
        ConfigOption(
            'confirm_shutdown', True, 'bool',
            'Require manual confirmation before shutting down instances'),
        ConfigOption('email_subject', 'Required tags audit notification',
                     'string', 'Subject of the email notification'),
        ConfigOption('enabled', False, 'bool',
                     'Enable the Required Tags auditor'),
        ConfigOption(
            'enable_delete_s3_buckets', True, 'bool',
            'Enable actual S3 bucket deletion. This might make you vulnerable to domain hijacking'
        ),
        ConfigOption('grace_period', 4, 'int',
                     'Only audit resources X minutes after being created'),
        ConfigOption('interval', 30, 'int',
                     'How often the auditor executes, in minutes.'),
        ConfigOption('partial_owner_match', True, 'bool',
                     'Allow partial matches of the Owner tag'),
        ConfigOption('permanent_recipient', [], 'array',
                     'List of email addresses to receive all alerts'),
        ConfigOption('required_tags', ['owner', 'accounting', 'name'], 'array',
                     'List of required tags'),
        ConfigOption(
            'lifecycle_expiration_days', 3, 'int',
            'How many days we should set in the bucket policy for non-empty S3 buckets removal'
        ),
        ConfigOption('gdpr_enabled', False, 'bool',
                     'Enable auditing for GDPR compliance'),
        ConfigOption('gdpr_accounts', [], 'array',
                     'List of accounts requiring GDPR compliance'),
        ConfigOption('gdpr_tag', 'gdpr_compliance', 'string',
                     'Name of GDPR compliance tag'),
        ConfigOption('gdpr_tag_values', ['pending', 'v1'], 'array',
                     'List of valid values for GDPR compliance tag'))

    def __init__(self):
        super().__init__()
        self.log.debug('Starting RequiredTags auditor')

        self.required_tags = dbconfig.get('required_tags', self.ns,
                                          ['owner', 'accounting', 'name'])
        self.collect_only = dbconfig.get('collect_only', self.ns, True)
        self.always_send_email = dbconfig.get('always_send_email', self.ns,
                                              False)
        self.permanent_emails = [{
            'type': 'email',
            'value': contact
        } for contact in dbconfig.get('permanent_recipient', self.ns, [])]
        self.email_subject = dbconfig.get('email_subject', self.ns,
                                          'Required tags audit notification')
        self.grace_period = dbconfig.get('grace_period', self.ns, 4)
        self.partial_owner_match = dbconfig.get('partial_owner_match', self.ns,
                                                True)
        self.audit_ignore_tag = dbconfig.get('audit_ignore_tag',
                                             NS_AUDITOR_REQUIRED_TAGS)
        self.alert_schedule = dbconfig.get('alert_settings',
                                           NS_AUDITOR_REQUIRED_TAGS)
        self.audited_types = dbconfig.get('audit_scope',
                                          NS_AUDITOR_REQUIRED_TAGS)['enabled']
        self.email_from_address = dbconfig.get('from_address', NS_EMAIL)
        self.resource_types = {
            resource_type.resource_type_id: resource_type.resource_type
            for resource_type in db.ResourceType.find()
        }
        self.gdpr_enabled = dbconfig.get('gdpr_enabled', self.ns, False)
        self.gdpr_accounts = dbconfig.get('gdpr_accounts', self.ns, [])
        self.gdpr_tag = dbconfig.get('gdpr_tag', self.ns, 'gdpr_compliance')
        self.gdpr_tag_values = dbconfig.get('gdpr_tag_values', self.ns,
                                            ['pending', 'v1'])
        self.resource_classes = {
            resource.resource_type: resource
            for resource in map(
                lambda plugin: plugin.load(),
                CINQ_PLUGINS['cloud_inquisitor.plugins.types']['plugins'])
        }

    def run(self, *args, **kwargs):
        known_issues, new_issues, fixed_issues = self.get_resources()
        known_issues += self.create_new_issues(new_issues)
        actions = [
            *[{
                'action': AuditActions.FIXED,
                'action_description': None,
                'last_alert': issue.last_alert,
                'issue': issue,
                'resource': issue.resource,
                'owners': self.get_contacts(issue),
                'notes': issue.notes,
                'missing_tags': issue.missing_tags
            } for issue in fixed_issues], *self.get_actions(known_issues)
        ]
        notifications = self.process_actions(actions)
        self.notify(notifications)

    def get_known_resources_missing_tags(self):
        non_compliant_resources = {}
        audited_types = dbconfig.get('audit_scope', NS_AUDITOR_REQUIRED_TAGS,
                                     {'enabled': []})['enabled']

        try:
            # resource_info is a tuple with the resource typename as [0] and the resource class as [1]
            resources = filter(
                lambda resource_info: resource_info[0] in audited_types,
                self.resource_classes.items())
            for resource_name, resource_class in resources:
                for resource_id, resource in resource_class.get_all().items():
                    missing_tags, notes = self.check_required_tags_compliance(
                        resource)
                    if missing_tags:
                        # Not really a get, it generates a new resource ID
                        issue_id = get_resource_id('reqtag', resource_id)
                        non_compliant_resources[issue_id] = {
                            'issue_id': issue_id,
                            'missing_tags': missing_tags,
                            'notes': notes,
                            'resource_id': resource_id,
                            'resource': resource
                        }
        finally:
            db.session.rollback()
        return non_compliant_resources

    def get_resources(self):
        found_issues = self.get_known_resources_missing_tags()
        existing_issues = RequiredTagsIssue.get_all().items()
        known_issues = []
        fixed_issues = []

        for existing_issue_id, existing_issue in existing_issues:
            # Check if the existing issue is still persists
            resource = found_issues.pop(existing_issue_id, None)
            if resource:
                if resource['missing_tags'] != existing_issue.missing_tags:
                    existing_issue.set_property('missing_tags',
                                                resource['missing_tags'])
                if resource['notes'] != existing_issue.notes:
                    existing_issue.set_property('notes', resource['notes'])
                db.session.add(existing_issue.issue)
                known_issues.append(existing_issue)
            else:
                fixed_issues.append(existing_issue)

        new_issues = {}
        for resource_id, resource in found_issues.items():
            try:
                if ((datetime.utcnow() -
                     resource['resource'].resource_creation_date
                     ).total_seconds() // 3600) >= self.grace_period:
                    new_issues[resource_id] = resource
            except Exception as ex:
                self.log.error(
                    'Failed to construct new issue {}, Error: {}'.format(
                        resource_id, ex))

        db.session.commit()
        return known_issues, new_issues, fixed_issues

    def create_new_issues(self, new_issues):
        try:
            for non_compliant_resource in new_issues.values():
                properties = {
                    'resource_id':
                    non_compliant_resource['resource_id'],
                    'account_id':
                    non_compliant_resource['resource'].account_id,
                    'location':
                    non_compliant_resource['resource'].location,
                    'created':
                    time.time(),
                    'last_alert':
                    '-1 seconds',
                    'missing_tags':
                    non_compliant_resource['missing_tags'],
                    'notes':
                    non_compliant_resource['notes'],
                    'resource_type':
                    non_compliant_resource['resource'].resource_name
                }
                issue = RequiredTagsIssue.create(
                    non_compliant_resource['issue_id'], properties=properties)
                self.log.info('Trying to add new issue / {} {}'.format(
                    properties['resource_id'], str(issue)))
                db.session.add(issue.issue)
                db.session.commit()
                yield issue
        except Exception as e:
            self.log.info('Could not add new issue / {}'.format(e))
        finally:
            db.session.rollback()

    def get_contacts(self, issue):
        """Returns a list of contacts for an issue

        Args:
            issue (:obj:`RequiredTagsIssue`): Issue record

        Returns:
            `list` of `dict`
        """
        # If the resources has been deleted, just return an empty list, to trigger issue deletion without notification
        if not issue.resource:
            return []

        account_contacts = issue.resource.account.contacts
        try:
            resource_owners = issue.resource.get_owner_emails()
            # Double check get_owner_emails for it's return value
            if type(resource_owners) is list:
                for resource_owner in resource_owners:
                    account_contacts.append({
                        'type': 'email',
                        'value': resource_owner
                    })
        except AttributeError:
            pass
        return account_contacts

    def get_actions(self, issues):
        """Returns a list of actions to executed

        Args:
            issues (`list` of :obj:`RequiredTagsIssue`): List of issues

        Returns:
            `list` of `dict`
        """
        actions = []
        try:
            for issue in issues:
                action_item = self.determine_action(issue)
                if action_item['action'] != AuditActions.IGNORE:
                    action_item['owners'] = self.get_contacts(issue)
                    actions.append(action_item)
        finally:
            db.session.rollback()
        return actions

    def determine_alert(self, action_schedule, issue_creation_time,
                        last_alert):
        """Determine if we need to trigger an alert

        Args:
            action_schedule (`list`): A list contains the alert schedule
            issue_creation_time (`int`): Time we create the issue
            last_alert (`str`): Time we sent the last alert

        Returns:
            (`None` or `str`)
            None if no alert should be sent. Otherwise return the alert we should send
        """
        issue_age = time.time() - issue_creation_time
        alert_schedule_lookup = {
            pytimeparse.parse(action_time): action_time
            for action_time in action_schedule
        }
        alert_schedule = sorted(alert_schedule_lookup.keys())
        last_alert_time = pytimeparse.parse(last_alert)

        for alert_time in alert_schedule:
            if last_alert_time < alert_time <= issue_age and last_alert_time != alert_time:
                return alert_schedule_lookup[alert_time]
        else:
            return None

    def determine_action(self, issue):
        """Determine the action we should take for the issue

        Args:
            issue: Issue to determine action for

        Returns:
             `dict`
        """
        resource_type = self.resource_types[issue.resource.resource_type_id]
        issue_alert_schedule = self.alert_schedule[resource_type] if \
            resource_type in self.alert_schedule \
            else self.alert_schedule['*']

        action_item = {
            'action':
            None,
            'action_description':
            None,
            'last_alert':
            issue.last_alert,
            'issue':
            issue,
            'resource':
            self.resource_classes[self.resource_types[
                issue.resource.resource_type_id]](issue.resource),
            'owners': [],
            'stop_after':
            issue_alert_schedule['stop'],
            'remove_after':
            issue_alert_schedule['remove'],
            'notes':
            issue.notes,
            'missing_tags':
            issue.missing_tags
        }

        time_elapsed = time.time() - issue.created
        stop_schedule = pytimeparse.parse(issue_alert_schedule['stop'])
        remove_schedule = pytimeparse.parse(issue_alert_schedule['remove'])

        if self.collect_only:
            action_item['action'] = AuditActions.IGNORE
        elif remove_schedule and time_elapsed >= remove_schedule:
            action_item['action'] = AuditActions.REMOVE
            action_item['action_description'] = 'Resource removed'
            action_item['last_alert'] = remove_schedule

        elif stop_schedule and time_elapsed >= stop_schedule:
            if issue.get_property('state').value == AuditActions.STOP:
                action_item['action'] = AuditActions.IGNORE
            else:
                action_item['action'] = AuditActions.STOP
                action_item['action_description'] = 'Resource stopped'
                action_item['last_alert'] = stop_schedule

        else:
            alert_selection = self.determine_alert(
                issue_alert_schedule['alert'],
                issue.get_property('created').value,
                issue.get_property('last_alert').value)
            if alert_selection:
                action_item['action'] = AuditActions.ALERT
                action_item['action_description'] = '{} alert'.format(
                    alert_selection)
                action_item['last_alert'] = alert_selection
            else:
                action_item['action'] = AuditActions.IGNORE

        return action_item

    def process_action(self, resource, action):
        return process_action(resource, action, self.ns)

    def process_actions(self, actions):
        """Process the actions we want to take

        Args:
            actions (`list`): List of actions we want to take

        Returns:
            `list` of notifications
        """
        notices = {}
        notification_contacts = {}
        for action in actions:
            resource = action['resource']
            action_status = ActionStatus.SUCCEED

            try:
                if action['action'] == AuditActions.REMOVE:
                    action_status = self.process_action(
                        resource, AuditActions.REMOVE)
                    if action_status == ActionStatus.SUCCEED:
                        db.session.delete(action['issue'].issue)

                elif action['action'] == AuditActions.STOP:
                    action_status = self.process_action(
                        resource, AuditActions.STOP)
                    if action_status == ActionStatus.SUCCEED:
                        action['issue'].update({
                            'missing_tags':
                            action['missing_tags'],
                            'notes':
                            action['notes'],
                            'last_alert':
                            action['last_alert'],
                            'state':
                            action['action']
                        })

                elif action['action'] == AuditActions.FIXED:
                    db.session.delete(action['issue'].issue)

                elif action['action'] == AuditActions.ALERT:
                    action['issue'].update({
                        'missing_tags':
                        action['missing_tags'],
                        'notes':
                        action['notes'],
                        'last_alert':
                        action['last_alert'],
                        'state':
                        action['action']
                    })

                db.session.commit()

                if action_status == ActionStatus.SUCCEED:
                    for owner in [
                            dict(t) for t in {
                                tuple(d.items())
                                for d in (action['owners'] +
                                          self.permanent_emails)
                            }
                    ]:
                        if owner['value'] not in notification_contacts:
                            contact = NotificationContact(type=owner['type'],
                                                          value=owner['value'])
                            notification_contacts[owner['value']] = contact
                            notices[contact] = {'fixed': [], 'not_fixed': []}
                        else:
                            contact = notification_contacts[owner['value']]

                        if action['action'] == AuditActions.FIXED:
                            notices[contact]['fixed'].append(action)
                        else:
                            notices[contact]['not_fixed'].append(action)
            except Exception as ex:
                self.log.exception(
                    'Unexpected error while processing resource {}/{}/{}/{}'.
                    format(action['resource'].account.account_name,
                           action['resource'].id, action['resource'], ex))

        return notices

    def validate_tag(self, key, value):
        """Check whether a tag value is valid

        Args:
            key: A tag key
            value: A tag value

        Returns:
            `(True or False)`
            A boolean indicating whether or not the value is valid
        """
        if key == 'owner':
            return validate_email(value, self.partial_owner_match)
        elif key == self.gdpr_tag:
            return value in self.gdpr_tag_values
        else:
            return True

    def check_required_tags_compliance(self, resource):
        """Check whether a resource is compliance

        Args:
            resource: A single resource

        Returns:
            `(list, list)`
            A tuple contains missing tags (if there were any) and notes
        """

        missing_tags = []
        notes = []
        resource_tags = {tag.key.lower(): tag.value for tag in resource.tags}

        # Do not audit this resource if it is not in the Account scope
        if resource.resource_type in self.alert_schedule:
            target_accounts = self.alert_schedule[
                resource.resource_type]['scope']
        else:
            target_accounts = self.alert_schedule['*']['scope']
        if not (resource.account.account_name in target_accounts
                or '*' in target_accounts):
            return missing_tags, notes

        # Do not audit this resource if the ignore tag was set
        if self.audit_ignore_tag.lower() in resource_tags:
            return missing_tags, notes

        required_tags = list(self.required_tags)

        # Add GDPR tag to required tags if the account must be GDPR compliant
        if self.gdpr_enabled and resource.account.account_name in self.gdpr_accounts:
            required_tags.append(self.gdpr_tag)
        '''
        # Do not audit this resource if it is still in grace period
        if (datetime.utcnow() - resource.resource_creation_date).total_seconds() // 3600 < self.grace_period:
            return missing_tags, notes
        '''

        # Check if the resource is missing required tags or has invalid tag values
        for key in [tag.lower() for tag in required_tags]:
            if key not in resource_tags:
                missing_tags.append(key)
            elif not self.validate_tag(key, resource_tags[key]):
                missing_tags.append(key)
                notes.append('{} tag is not valid'.format(key))

        if missing_tags and resource.resource_type == 'aws_rds_instance':
            notes.append('Instance name = {}'.format(resource.instance_name))

        return missing_tags, notes

    def notify(self, notices):
        """Send notifications to the recipients provided

        Args:
            notices (:obj:`dict` of `str`: `list`): A dictionary mapping notification messages to the recipient.

        Returns:
            `None`
        """
        tmpl_html = get_template('required_tags_notice.html')
        tmpl_text = get_template('required_tags_notice.txt')
        for recipient, data in list(notices.items()):
            body_html = tmpl_html.render(data=data)
            body_text = tmpl_text.render(data=data)

            send_notification(subsystem=self.ns,
                              recipients=[recipient],
                              subject=self.email_subject,
                              body_html=body_html,
                              body_text=body_text)
Exemple #3
0
class VPCFlowLogsAuditor(BaseAuditor):
    name = 'VPC Flow Log Compliance'
    ns = NS_AUDITOR_VPC_FLOW_LOGS
    interval = dbconfig.get('interval', ns, 60)
    role_name = dbconfig.get('role_name', ns, 'VpcFlowLogsRole')
    start_delay = 0
    options = (ConfigOption('enabled', False, 'bool',
                            'Enable the VPC Flow Logs auditor'),
               ConfigOption('interval', 60, 'int', 'Run frequency in minutes'),
               ConfigOption('role_name', 'VpcFlowLogsRole', 'str',
                            'Name of IAM Role used for VPC Flow Logs'))

    def __init__(self):
        super().__init__()
        self.session = None

    def run(self):
        """Main entry point for the auditor worker.

        Returns:
            `None`
        """
        # Loop through all accounts that are marked as enabled
        accounts = list(AWSAccount.get_all(include_disabled=False).values())
        for account in accounts:
            self.log.debug('Updating VPC Flow Logs for {}'.format(account))

            self.session = get_aws_session(account)
            role_arn = self.confirm_iam_role(account)
            # region specific
            for aws_region in AWS_REGIONS:
                try:
                    vpc_list = VPC.get_all(account, aws_region).values()
                    need_vpc_flow_logs = [
                        x for x in vpc_list
                        if x.vpc_flow_logs_status != 'ACTIVE'
                    ]

                    for vpc in need_vpc_flow_logs:
                        if self.confirm_cw_log(account, aws_region, vpc.id):
                            self.create_vpc_flow_logs(account, aws_region,
                                                      vpc.id, role_arn)
                        else:
                            self.log.info(
                                'Failed to confirm log group for {}/{}'.format(
                                    account, aws_region))

                except Exception:
                    self.log.exception(
                        'Failed processing VPCs for {}/{}.'.format(
                            account, aws_region))

            db.session.commit()

    @retry
    def confirm_iam_role(self, account):
        """Return the ARN of the IAM Role on the provided account as a string. Returns an `IAMRole` object from boto3

        Args:
            account (:obj:`Account`): Account where to locate the role

        Returns:
            :obj:`IAMRole`
        """
        try:
            iam = self.session.client('iam')
            rolearn = iam.get_role(RoleName=self.role_name)['Role']['Arn']
            return rolearn

        except ClientError as e:
            if e.response['Error']['Code'] == 'NoSuchEntity':
                self.create_iam_role(account)
            else:
                raise

        except Exception as e:
            self.log.exception(
                'Failed validating IAM role for VPC Flow Log Auditing for {}'.
                format(e))

    @retry
    def create_iam_role(self, account):
        """Create a new IAM role. Returns the ARN of the newly created role

        Args:
            account (:obj:`Account`): Account where to create the IAM role

        Returns:
            `str`
        """
        try:
            iam = self.session.client('iam')
            trust = get_template('vpc_flow_logs_iam_role_trust.json').render()
            policy = get_template('vpc_flow_logs_role_policy.json').render()

            newrole = iam.create_role(
                Path='/',
                RoleName=self.role_name,
                AssumeRolePolicyDocument=trust)['Role']['Arn']

            # Attach an inline policy to the role to avoid conflicts or hitting the Managed Policy Limit
            iam.put_role_policy(RoleName=self.role_name,
                                PolicyName='VpcFlowPolicy',
                                PolicyDocument=policy)

            self.log.debug('Created VPC Flow Logs role & policy for {}'.format(
                account.account_name))
            auditlog(event='vpc_flow_logs.create_iam_role',
                     actor=self.ns,
                     data={
                         'account': account.account_name,
                         'roleName': self.role_name,
                         'trustRelationship': trust,
                         'inlinePolicy': policy
                     })
            return newrole

        except Exception:
            self.log.exception(
                'Failed creating the VPC Flow Logs role for {}.'.format(
                    account))

    @retry
    def confirm_cw_log(self, account, region, vpcname):
        """Create a new CloudWatch log group based on the VPC Name if none exists. Returns `True` if succesful

        Args:
            account (:obj:`Account`): Account to create the log group in
            region (`str`): Region to create the log group in
            vpcname (`str`): Name of the VPC the log group is fow

        Returns:
            `bool`
        """
        try:
            cw = self.session.client('logs', region)
            token = None
            log_groups = []
            while True:
                result = cw.describe_log_groups(
                ) if not token else cw.describe_log_groups(nextToken=token)
                token = result.get('nextToken')
                log_groups.extend(
                    [x['logGroupName'] for x in result.get('logGroups', [])])

                if not token:
                    break

            if vpcname not in log_groups:
                cw.create_log_group(logGroupName=vpcname)

                cw_vpc = VPC.get(vpcname)
                cw_vpc.set_property('vpc_flow_logs_log_group', vpcname)

                self.log.info('Created log group {}/{}/{}'.format(
                    account.account_name, region, vpcname))
                auditlog(event='vpc_flow_logs.create_cw_log_group',
                         actor=self.ns,
                         data={
                             'account': account.account_name,
                             'region': region,
                             'log_group_name': vpcname,
                             'vpc': vpcname
                         })
            return True

        except Exception:
            self.log.exception(
                'Failed creating log group for {}/{}/{}.'.format(
                    account, region, vpcname))

    @retry
    def create_vpc_flow_logs(self, account, region, vpc_id, iam_role_arn):
        """Create a new VPC Flow log

        Args:
            account (:obj:`Account`): Account to create the flow in
            region (`str`): Region to create the flow in
            vpc_id (`str`): ID of the VPC to create the flow for
            iam_role_arn (`str`): ARN of the IAM role used to post logs to the log group

        Returns:
            `None`
        """
        try:
            flow = self.session.client('ec2', region)
            flow.create_flow_logs(ResourceIds=[vpc_id],
                                  ResourceType='VPC',
                                  TrafficType='ALL',
                                  LogGroupName=vpc_id,
                                  DeliverLogsPermissionArn=iam_role_arn)
            fvpc = VPC.get(vpc_id)
            fvpc.set_property('vpc_flow_logs_status', 'ACTIVE')

            self.log.info('Enabled VPC Logging {}/{}/{}'.format(
                account, region, vpc_id))
            auditlog(event='vpc_flow_logs.create_vpc_flow',
                     actor=self.ns,
                     data={
                         'account': account.account_name,
                         'region': region,
                         'vpcId': vpc_id,
                         'arn': iam_role_arn
                     })
        except Exception:
            self.log.exception(
                'Failed creating VPC Flow Logs for {}/{}/{}.'.format(
                    account, region, vpc_id))
Exemple #4
0
class EmailNotifier(BaseNotifier):
    name = 'Email Notifier'
    ns = NS_EMAIL
    notifier_type = 'email'
    validation = RGX_EMAIL_VALIDATION_PATTERN
    options = (
        ConfigOption('enabled', True, 'bool',
                     'Enable the Email notifier plugin'),
        ConfigOption('from_address', '*****@*****.**', 'string',
                     'Sender address for emails'),
        ConfigOption('method', 'ses', 'string',
                     'EMail sending method, either ses or smtp'),
        ConfigOption(
            'from_arn', '', 'string',
            'If using cross-account SES, this is the "From ARN", otherwise leave blank'
        ),
        ConfigOption(
            'return_path_arn', '', 'string',
            'If using cross-account SES, this is the "Return Path ARN", otherwise leave blank'
        ),
        ConfigOption(
            'source_arn', '', 'string',
            'If using cross-account SES, this is the "Source ARN", otherwise leave blank'
        ),
        ConfigOption('ses_region', 'us-west-2', 'string',
                     'Which SES region to send emails from'),
        ConfigOption('smtp_server', 'localhost', 'string',
                     'Address of the SMTP server to use'),
        ConfigOption('smtp_port', 25, 'int', 'Port for the SMTP server'),
        ConfigOption(
            'smtp_username', '', 'string',
            'Username for SMTP authentication. Leave blank for no authentication'
        ),
        ConfigOption(
            'smtp_password', '', 'string',
            'Password for SMTP authentication. Leave blank for no authentication'
        ),
        ConfigOption('smtp_tls', False, 'bool', 'Use TLS for sending emails'),
    )

    def __init__(self):
        super().__init__()
        self.sender = self.dbconfig.get('from_address', NS_EMAIL)

    def notify(self, subsystem, recipient, subject, body_html, body_text):
        """Method to send a notification. A plugin may use only part of the information, but all fields are required.

        Args:
            subsystem (`str`): Name of the subsystem originating the notification
            recipient (`str`): Recipient email address
            subject (`str`): Subject / title of the notification
            body_html (`str)`: HTML formatted version of the message
            body_text (`str`): Text formatted version of the message

        Returns:
            `None`
        """
        if not re.match(RGX_EMAIL_VALIDATION_PATTERN, recipient, re.I):
            raise ValueError('Invalid recipient provided')

        email = Email()
        email.timestamp = datetime.now()
        email.subsystem = subsystem
        email.sender = self.sender
        email.recipients = recipient
        email.subject = subject
        email.uuid = uuid.uuid4()
        email.message_html = body_html
        email.message_text = body_text

        method = dbconfig.get('method', NS_EMAIL, 'ses')
        try:
            if method == 'ses':
                self.__send_ses_email([recipient], subject, body_html,
                                      body_text)

            elif method == 'smtp':
                self.__send_smtp_email([recipient], subject, body_html,
                                       body_text)

            else:
                raise ValueError('Invalid email method: {}'.format(method))

            db.session.add(email)
            db.session.commit()
        except Exception as ex:
            raise EmailSendError(ex)

    def __send_ses_email(self, recipients, subject, body_html, body_text):
        """Send an email using SES

        Args:
            recipients (`1ist` of `str`): List of recipient email addresses
            subject (str): Subject of the email
            body_html (str): HTML body of the email
            body_text (str): Text body of the email

        Returns:
            `None`
        """
        source_arn = dbconfig.get('source_arn', NS_EMAIL)
        return_arn = dbconfig.get('return_path_arn', NS_EMAIL)

        session = get_local_aws_session()
        ses = session.client('ses',
                             region_name=dbconfig.get('ses_region', NS_EMAIL,
                                                      'us-west-2'))

        body = {}
        if body_html:
            body['Html'] = {'Data': body_html}
        if body_text:
            body['Text'] = {'Data': body_text}

        ses_options = {
            'Source': self.sender,
            'Destination': {
                'ToAddresses': recipients
            },
            'Message': {
                'Subject': {
                    'Data': subject
                },
                'Body': body
            }
        }

        # Set SES options if needed
        if source_arn and return_arn:
            ses_options.update({
                'SourceArn': source_arn,
                'ReturnPathArn': return_arn
            })

        ses.send_email(**ses_options)

    def __send_smtp_email(self, recipients, subject, html_body, text_body):
        """Send an email using SMTP

        Args:
            recipients (`list` of `str`): List of recipient email addresses
            subject (str): Subject of the email
            html_body (str): HTML body of the email
            text_body (str): Text body of the email

        Returns:
            `None`
        """
        smtp = smtplib.SMTP(dbconfig.get('smtp_server', NS_EMAIL, 'localhost'),
                            dbconfig.get('smtp_port', NS_EMAIL, 25))
        source_arn = dbconfig.get('source_arn', NS_EMAIL)
        return_arn = dbconfig.get('return_path_arn', NS_EMAIL)
        from_arn = dbconfig.get('from_arn', NS_EMAIL)
        msg = MIMEMultipart('alternative')

        # Set SES options if needed
        if source_arn and from_arn and return_arn:
            msg['X-SES-SOURCE-ARN'] = source_arn
            msg['X-SES-FROM-ARN'] = from_arn
            msg['X-SES-RETURN-PATH-ARN'] = return_arn

        msg['Subject'] = subject
        msg['To'] = ','.join(recipients)
        msg['From'] = self.sender

        # Check body types to avoid exceptions
        if html_body:
            html_part = MIMEText(html_body, 'html')
            msg.attach(html_part)
        if text_body:
            text_part = MIMEText(text_body, 'plain')
            msg.attach(text_part)

        # TLS if needed
        if dbconfig.get('smtp_tls', NS_EMAIL, False):
            smtp.starttls()

        # Login if needed
        username = dbconfig.get('smtp_username', NS_EMAIL)
        password = dbconfig.get('smtp_password', NS_EMAIL)
        if username and password:
            smtp.login(username, password)

        smtp.sendmail(self.sender, recipients, msg.as_string())
        smtp.quit()
Exemple #5
0
class SlackNotifier(BaseNotifier):
    name = 'Slack Notifier'
    ns = NS_SLACK
    enabled = dbconfig.get('enabled', ns, True)
    options = (
        ConfigOption('enabled', False, 'bool', 'Enable the Slack notifier plugin'),
        ConfigOption('api_key', '', 'string', 'API token for the slack notifications'),
        ConfigOption('bot_name', 'Inquisitor', 'string', 'Name of the bot in Slack'),
    )

    def __init__(self, api_key=None):
        super().__init__()

        if not self.enabled:
            raise SlackError('Slack messaging is disabled')

        self.slack_client = SlackClient(api_key or dbconfig.get('api_key', self.ns))
        self.bot_name = dbconfig.get('bot_name', self.ns, 'Inquisitor')

        if not self.__check():
            raise SlackError('Invalid API KEY!')

    def __check(self):
        try:
            response = self.slack_client.api_call('auth.test')
            return response['ok']
        except Exception:
            return False

    def __get_user_id(self, email):
        response = self.slack_client.api_call('users.list')
        try:
            if not response['ok']:
                raise SlackError('Failed to list Slack users!')
            for item in response['members']:
                _profile = item['profile']
                if _profile.get('email', None) == email:
                    return item['id']
            else:
                SlackError('Failed to get user from Slack!')
        except Exception as ex:
            raise SlackError(ex)

    def __get_channel_for_user(self, user_email):
        user_id = self.__get_user_id(user_email)
        try:
            response = self.slack_client.api_call('im.open', user=user_id)
            if not response['ok']:
                raise SlackError('Failed to get channel for user!')

            return response['channel']['id']
        except Exception as ex:
            raise SlackError(ex)

    def _send_message(self, target_type, target, message):
        if target_type == 'user':
            channel = self.__get_channel_for_user(target)
        else:
            channel = target

        result = self.slack_client.api_call(
            'chat.postMessage',
            channel=channel,
            text=message,
            username=self.bot_name
        )
        if not result.get('ok', False):
            raise SlackError('Failed to send message: {}'.format(result['error']))

    @staticmethod
    def send_message(contacts, message):
        """List of contacts the send the message to. You can send messages either to channels and private groups by using
        the following formats

        #channel-name
        @username-direct-message

        If the channel is the name of a private group / channel, you must first invite the bot to the channel to ensure it
        is allowed to send messages to the group.

        Returns true if the message was sent, else `False`

        Args:
            contacts (:obj:`list` of `str`,`str`): List of contacts
            message (str): Message to send

        Returns:
            `bool`
        """
        slack_api_object = SlackNotifier()

        if type(contacts) == str:
            contacts = [contacts]

        for contact in contacts:
            if contact.startswith('#'):
                target_type = 'channel'

            elif '@' in contact:
                target_type = 'user'

            else:
                raise SlackError('Unrecognized contact {}'.format(contact))

            slack_api_object._send_message(
                target_type=target_type,
                target=contact,
                message=message
            )

            return True
Exemple #6
0
    def run(self, **kwargs):
        # Setup the base application settings
        defaults = [
            {
                'prefix':
                'default',
                'name':
                'Default',
                'sort_order':
                0,
                'options': [
                    ConfigOption('debug', False, 'bool',
                                 'Enable debug mode for flask'),
                    ConfigOption('session_expire_time', 43200, 'int',
                                 'Time in seconds before sessions expire'),
                    ConfigOption(
                        'role_name', 'cinq_role', 'string',
                        'Role name Cloud Inquisitor will use in each account'),
                    ConfigOption(
                        'ignored_aws_regions_regexp', '(^cn-|GLOBAL|-gov)',
                        'string',
                        'A regular expression used to filter out regions from the AWS static data'
                    ),
                    ConfigOption(name='auth_system',
                                 default_value={
                                     'enabled': ['Local Authentication'],
                                     'available': ['Local Authentication'],
                                     'max_items': 1,
                                     'min_items': 1
                                 },
                                 type='choice',
                                 description='Enabled authentication module'),
                    ConfigOption('scheduler', 'StandaloneScheduler', 'string',
                                 'Default scheduler module'),
                    ConfigOption(
                        'jwt_key_file_path', 'ssl/private.key', 'string',
                        'Path to the private key used to encrypt JWT session tokens. Can be relative to the '
                        'folder containing the configuration file, or absolute path'
                    )
                ],
            },
            {
                'prefix':
                'log',
                'name':
                'Logging',
                'sort_order':
                1,
                'options': [
                    ConfigOption('log_level', 'INFO', 'string', 'Log level'),
                    ConfigOption(
                        'enable_syslog_forwarding', False, 'bool',
                        'Enable forwarding logs to remote log collector'),
                    ConfigOption('remote_syslog_server_addr', '127.0.0.1',
                                 'string',
                                 'Address of the remote log collector'),
                    ConfigOption('remote_syslog_server_port', 514, 'string',
                                 'Port of the remote log collector'),
                    ConfigOption('log_keep_days', 31, 'int',
                                 'Delete log entries older than n days'),
                ],
            },
            {
                'prefix':
                'api',
                'name':
                'API',
                'sort_order':
                2,
                'options': [
                    ConfigOption('host', '127.0.0.1', 'string',
                                 'Host of the API server'),
                    ConfigOption('port', 5000, 'int',
                                 'Port of the API server'),
                    ConfigOption(
                        'workers', 6, 'int',
                        'Number of worker processes spawned for the API')
                ]
            },
        ]

        # Setup all the default base settings
        for data in defaults:
            nsobj = self.get_config_namespace(data['prefix'],
                                              data['name'],
                                              sort_order=data['sort_order'])
            for opt in data['options']:
                self.register_default_option(nsobj, opt)
            db.session.add(nsobj)
            db.session.commit()

        # Iterate over all of our plugins and setup their defaults
        for ptype, namespaces in list(PLUGIN_NAMESPACES.items()):
            for ns in namespaces:
                for ep in iter_entry_points(ns):
                    cls = ep.load()
                    if hasattr(cls, 'ns'):
                        ns_name = '{}: {}'.format(ptype.capitalize(), cls.name)
                        nsobj = self.get_config_namespace(cls.ns, ns_name)
                        if not isinstance(cls.options, abstractproperty):
                            if cls.options:
                                for opt in cls.options:
                                    self.register_default_option(nsobj, opt)
                        db.session.add(nsobj)
                        db.session.commit()

        # Create the default roles if they are missing
        self.add_default_roles()

        # If there are no accounts created, ask the user if he/she wants to create one now
        if not kwargs['headless_mode'] and not Account.query.first():
            if confirm(
                    'You have no accounts defined, do you wish to add the first account now?'
            ):
                self.init_account()
Exemple #7
0
class IAMAuditor(BaseAuditor):
    """Validate and apply IAM policies for AWS Accounts
    """
    name = 'IAM'
    ns = NS_AUDITOR_IAM
    interval = dbconfig.get('interval', ns, 30)
    start_delay = 0
    manage_roles = dbconfig.get('manage_roles', ns, True)
    git_policies = None
    cfg_roles = None
    aws_managed_policies = None
    options = (
        ConfigOption('enabled', False, 'bool',
                     'Enable the IAM roles and policy auditor'),
        ConfigOption('interval', 30, 'int',
                     'How often the auditor executes, in minutes'),
        ConfigOption('manage_roles', True, 'bool',
                     'Enable management of IAM roles'),
        ConfigOption(
            'roles', '{ }', 'json',
            'JSON document with roles to push to accounts. See documentation for examples'
        ),
        ConfigOption('delete_inline_policies', False, 'bool',
                     'Delete inline policies from existing roles'),
        ConfigOption('git_auth_token', 'CHANGE ME', 'string',
                     'API Auth token for Github'),
        ConfigOption('git_server', 'CHANGE ME', 'string',
                     'Address of the Github server'),
        ConfigOption('git_repo', 'CHANGE ME', 'string', 'Name of Github repo'),
        ConfigOption('git_no_ssl_verify', False, 'bool',
                     'Disable SSL verification of Github server'),
        ConfigOption('role_timeout', 8, 'int', 'AssumeRole timeout in hours'))

    def run(self, *args, **kwargs):
        """Iterate through all AWS accounts and apply roles and policies from Github

        Args:
            *args: Optional list of arguments
            **kwargs: Optional list of keyword arguments

        Returns:
            `None`
        """
        accounts = list(AWSAccount.get_all(include_disabled=False).values())
        self.manage_policies(accounts)

    def manage_policies(self, accounts):
        if not accounts:
            return

        self.git_policies = self.get_policies_from_git()
        self.manage_roles = self.dbconfig.get('manage_roles', self.ns, True)
        self.cfg_roles = self.dbconfig.get('roles', self.ns)
        self.aws_managed_policies = {
            policy['PolicyName']: policy
            for policy in self.get_policies_from_aws(
                get_aws_session(accounts[0]).client('iam'), 'AWS')
        }

        for account in accounts:
            try:
                if not account.ad_group_base:
                    self.log.info(
                        'Account {} does not have AD Group Base set, skipping'.
                        format(account.account_name))
                    continue

                # List all policies and roles from AWS, and generate a list of policies from Git
                sess = get_aws_session(account)
                iam = sess.client('iam')

                aws_roles = {
                    role['RoleName']: role
                    for role in self.get_roles(iam)
                }
                aws_policies = {
                    policy['PolicyName']: policy
                    for policy in self.get_policies_from_aws(iam)
                }

                account_policies = copy.deepcopy(self.git_policies['GLOBAL'])

                if account.account_name in self.git_policies:
                    for role in self.git_policies[account.account_name]:
                        account_policies.update(
                            self.git_policies[account.account_name][role])

                aws_policies.update(
                    self.check_policies(account, account_policies,
                                        aws_policies))
                self.check_roles(account, aws_policies, aws_roles)
            except Exception as exception:
                self.log.info(
                    'Unable to process account {}. Unhandled Exception {}'.
                    format(account.account_name, exception))

    @retry
    def check_policies(self, account, account_policies, aws_policies):
        """Iterate through the policies of a specific account and create or update the policy if its missing or
        does not match the policy documents from Git. Returns a dict of all the policies added to the account
        (does not include updated policies)

        Args:
            account (:obj:`Account`): Account to check policies for
            account_policies (`dict` of `str`: `dict`): A dictionary containing all the policies for the specific
            account
            aws_policies (`dict` of `str`: `dict`): A dictionary containing the non-AWS managed policies on the account

        Returns:
            :obj:`dict` of `str`: `str`
        """
        self.log.debug('Fetching policies for {}'.format(account.account_name))
        sess = get_aws_session(account)
        iam = sess.client('iam')
        added = {}

        for policyName, account_policy in account_policies.items():
            # policies pulled from github a likely bytes and need to be converted
            if isinstance(account_policy, bytes):
                account_policy = account_policy.decode('utf-8')

            # Using re.sub instead of format since format breaks on the curly braces of json
            gitpol = json.loads(
                re.sub(r'{AD_Group}', account.ad_group_base
                       or account.account_name, account_policy))

            if policyName in aws_policies:
                pol = aws_policies[policyName]
                awspol = iam.get_policy_version(
                    PolicyArn=pol['Arn'], VersionId=pol['DefaultVersionId']
                )['PolicyVersion']['Document']

                if awspol != gitpol:
                    self.log.warn(
                        'IAM Policy {} on {} does not match Git policy documents, updating'
                        .format(policyName, account.account_name))

                    self.create_policy(account,
                                       iam,
                                       json.dumps(gitpol, indent=4),
                                       policyName,
                                       arn=pol['Arn'])
                else:
                    self.log.debug('IAM Policy {} on {} is up to date'.format(
                        policyName, account.account_name))
            else:
                self.log.warn('IAM Policy {} is missing on {}'.format(
                    policyName, account.account_name))
                response = self.create_policy(account, iam, json.dumps(gitpol),
                                              policyName)
                added[policyName] = response['Policy']

        return added

    @retry
    def check_roles(self, account, aws_policies, aws_roles):
        """Iterate through the roles of a specific account and create or update the roles if they're missing or
        does not match the roles from Git.

        Args:
            account (:obj:`Account`): The account to check roles on
            aws_policies (:obj:`dict` of `str`: `dict`): A dictionary containing all the policies for the specific
            account
            aws_roles (:obj:`dict` of `str`: `dict`): A dictionary containing all the roles for the specific account

        Returns:
            `None`
        """
        self.log.debug('Checking roles for {}'.format(account.account_name))
        max_session_duration = self.dbconfig.get('role_timeout_in_hours',
                                                 self.ns, 8) * 60 * 60
        sess = get_aws_session(account)
        iam = sess.client('iam')

        # Build a list of default role policies and extra account specific role policies
        account_roles = copy.deepcopy(self.cfg_roles)
        if account.account_name in self.git_policies:
            for role in self.git_policies[account.account_name]:
                if role in account_roles:
                    account_roles[role]['policies'] += list(
                        self.git_policies[account.account_name][role].keys())

        for role_name, data in list(account_roles.items()):
            if role_name not in aws_roles:
                iam.create_role(Path='/',
                                RoleName=role_name,
                                AssumeRolePolicyDocument=json.dumps(
                                    data['trust'], indent=4),
                                MaxSessionDuration=max_session_duration)
                self.log.info('Created role {}/{}'.format(
                    account.account_name, role_name))
            else:
                try:
                    if aws_roles[role_name][
                            'MaxSessionDuration'] != max_session_duration:
                        iam.update_role(
                            RoleName=aws_roles[role_name]['RoleName'],
                            MaxSessionDuration=max_session_duration)
                        self.log.info(
                            'Adjusted MaxSessionDuration for role {} in account {} to {} seconds'
                            .format(role_name, account.account_name,
                                    max_session_duration))
                except ClientError:
                    self.log.exception(
                        'Unable to adjust MaxSessionDuration for role {} in account {}'
                        .format(role_name, account.account_name))

            aws_role_policies = [
                x['PolicyName'] for x in iam.list_attached_role_policies(
                    RoleName=role_name)['AttachedPolicies']
            ]
            aws_role_inline_policies = iam.list_role_policies(
                RoleName=role_name)['PolicyNames']
            cfg_role_policies = data['policies']

            missing_policies = list(
                set(cfg_role_policies) - set(aws_role_policies))
            extra_policies = list(
                set(aws_role_policies) - set(cfg_role_policies))

            if aws_role_inline_policies:
                self.log.info(
                    'IAM Role {} on {} has the following inline policies: {}'.
                    format(role_name, account.account_name,
                           ', '.join(aws_role_inline_policies)))

                if self.dbconfig.get('delete_inline_policies', self.ns,
                                     False) and self.manage_roles:
                    for policy in aws_role_inline_policies:
                        iam.delete_role_policy(RoleName=role_name,
                                               PolicyName=policy)
                        auditlog(
                            event='iam.check_roles.delete_inline_role_policy',
                            actor=self.ns,
                            data={
                                'account': account.account_name,
                                'roleName': role_name,
                                'policy': policy
                            })

            if missing_policies:
                self.log.info(
                    'IAM Role {} on {} is missing the following policies: {}'.
                    format(role_name, account.account_name,
                           ', '.join(missing_policies)))
                if self.manage_roles:
                    for policy in missing_policies:
                        iam.attach_role_policy(
                            RoleName=role_name,
                            PolicyArn=aws_policies[policy]['Arn'])
                        auditlog(event='iam.check_roles.attach_role_policy',
                                 actor=self.ns,
                                 data={
                                     'account': account.account_name,
                                     'roleName': role_name,
                                     'policyArn': aws_policies[policy]['Arn']
                                 })

            if extra_policies:
                self.log.info(
                    'IAM Role {} on {} has the following extra policies applied: {}'
                    .format(role_name, account.account_name,
                            ', '.join(extra_policies)))

                for policy in extra_policies:
                    if policy in aws_policies:
                        polArn = aws_policies[policy]['Arn']
                    elif policy in self.aws_managed_policies:
                        polArn = self.aws_managed_policies[policy]['Arn']
                    else:
                        polArn = None
                        self.log.info(
                            'IAM Role {} on {} has an unknown policy attached: {}'
                            .format(role_name, account.account_name, policy))

                    if self.manage_roles and polArn:
                        iam.detach_role_policy(RoleName=role_name,
                                               PolicyArn=polArn)
                        auditlog(event='iam.check_roles.detach_role_policy',
                                 actor=self.ns,
                                 data={
                                     'account': account.account_name,
                                     'roleName': role_name,
                                     'policyArn': polArn
                                 })

    def get_policies_from_git(self):
        """Retrieve policies from the Git repo. Returns a dictionary containing all the roles and policies

        Returns:
            :obj:`dict` of `str`: `dict`
        """
        fldr = mkdtemp()
        try:
            url = 'https://{token}:x-oauth-basic@{server}/{repo}'.format(
                **{
                    'token': self.dbconfig.get('git_auth_token', self.ns),
                    'server': self.dbconfig.get('git_server', self.ns),
                    'repo': self.dbconfig.get('git_repo', self.ns)
                })

            policies = {'GLOBAL': {}}
            if self.dbconfig.get('git_no_ssl_verify', self.ns, False):
                os.environ['GIT_SSL_NO_VERIFY'] = '1'

            repo = Repo.clone_from(url, fldr)
            for obj in repo.head.commit.tree:
                name, ext = os.path.splitext(obj.name)

                # Read the standard policies
                if ext == '.json':
                    policies['GLOBAL'][name] = obj.data_stream.read()

                # Read any account role specific policies
                if name == 'roles' and obj.type == 'tree':
                    for account in [x for x in obj.trees]:
                        for role in [x for x in account.trees]:
                            role_policies = {
                                policy.name.replace('.json', ''):
                                policy.data_stream.read()
                                for policy in role.blobs
                                if policy.name.endswith('.json')
                            }

                            if account.name in policies:
                                if role.name in policies[account.name]:
                                    policies[account.name][
                                        role.name] += role_policies
                                else:
                                    policies[account.name][
                                        role.name] = role_policies
                            else:
                                policies[account.name] = {
                                    role.name: role_policies
                                }

            return policies
        finally:
            if os.path.exists(fldr) and os.path.isdir(fldr):
                shutil.rmtree(fldr)

    @staticmethod
    def get_policies_from_aws(client, scope='Local'):
        """Returns a list of all the policies currently applied to an AWS Account. Returns a list containing all the
        policies for the specified scope

        Args:
            client (:obj:`boto3.session.Session`): A boto3 Session object
            scope (`str`): The policy scope to use. Default: Local

        Returns:
            :obj:`list` of `dict`
        """
        done = False
        marker = None
        policies = []

        while not done:
            if marker:
                response = client.list_policies(Marker=marker, Scope=scope)
            else:
                response = client.list_policies(Scope=scope)

            policies += response['Policies']

            if response['IsTruncated']:
                marker = response['Marker']
            else:
                done = True

        return policies

    @staticmethod
    def get_roles(client):
        """Returns a list of all the roles for an account. Returns a list containing all the roles for the account.

        Args:
            client (:obj:`boto3.session.Session`): A boto3 Session object

        Returns:
            :obj:`list` of `dict`
        """
        done = False
        marker = None
        roles = []

        while not done:
            if marker:
                response = client.list_roles(Marker=marker)
            else:
                response = client.list_roles()

            roles += response['Roles']

            if response['IsTruncated']:
                marker = response['Marker']
            else:
                done = True

        return roles

    def create_policy(self, account, client, document, name, arn=None):
        """Create a new IAM policy.

        If the policy already exists, a new version will be added and if needed the oldest policy version not in use
        will be removed. Returns a dictionary containing the policy or version information

        Args:
            account (:obj:`Account`): Account to create the policy on
            client (:obj:`boto3.client`): A boto3 client object
            document (`str`): Policy document
            name (`str`): Name of the policy to create / update
            arn (`str`): Optional ARN for the policy to update

        Returns:
            `dict`
        """
        if not arn and not name:
            raise ValueError(
                'create_policy must be called with either arn or name in the argument list'
            )

        if arn:
            response = client.list_policy_versions(PolicyArn=arn)

            # If we're at the max of the 5 possible versions, remove the oldest version that is not
            # the currently active policy
            if len(response['Versions']) >= 5:
                version = [
                    x for x in sorted(response['Versions'],
                                      key=lambda k: k['CreateDate'])
                    if not x['IsDefaultVersion']
                ][0]

                self.log.info(
                    'Deleting oldest IAM Policy version {}/{}'.format(
                        arn, version['VersionId']))
                client.delete_policy_version(PolicyArn=arn,
                                             VersionId=version['VersionId'])
                auditlog(event='iam.check_roles.delete_policy_version',
                         actor=self.ns,
                         data={
                             'account': account.account_name,
                             'policyName': name,
                             'policyArn': arn,
                             'versionId': version['VersionId']
                         })

            res = client.create_policy_version(PolicyArn=arn,
                                               PolicyDocument=document,
                                               SetAsDefault=True)
        else:
            res = client.create_policy(PolicyName=name,
                                       PolicyDocument=document)

        auditlog(event='iam.check_roles.create_policy',
                 actor=self.ns,
                 data={
                     'account': account.account_name,
                     'policyName': name,
                     'policyArn': arn
                 })
        return res
Exemple #8
0
class DNSCollector(BaseCollector):
    name = 'DNS'
    ns = 'collector_dns'
    type = CollectorType.GLOBAL
    interval = dbconfig.get('interval', ns, 15)
    options = (ConfigOption('enabled', False, 'bool',
                            'Enable the DNS collector plugin'),
               ConfigOption('interval', 15, 'int', 'Run frequency in minutes'),
               ConfigOption('cloudflare_enabled', False, 'bool',
                            'Enable CloudFlare as a source for DNS records'),
               ConfigOption('axfr_enabled', False, 'bool',
                            'Enable using DNS Zone Transfers for records'))

    def __init__(self):
        super().__init__()

        self.axfr_enabled = self.dbconfig.get('axfr_enabled', self.ns, False)
        self.cloudflare_enabled = self.dbconfig.get('cloudflare_enabled',
                                                    self.ns, False)

        self.axfr_accounts = list(AXFRAccount.get_all().values())
        self.cf_accounts = list(CloudFlareAccount.get_all().values())

        self.cloudflare_initialized = defaultdict(lambda: False)
        self.cloudflare_session = {}

    def run(self):
        if self.axfr_enabled:
            try:
                for account in self.axfr_accounts:
                    records = self.get_axfr_records(account.server,
                                                    account.domains)
                    self.process_zones(records, account)
            except:
                self.log.exception('Failed processing domains via AXFR')

        if self.cloudflare_enabled:
            try:
                for account in self.cf_accounts:
                    records = self.get_cloudflare_records(account=account)
                    self.process_zones(records, account)
            except:
                self.log.exception('Failed processing domains via CloudFlare')

    def process_zones(self, zones, account):
        self.log.info('Processing DNS records for {}'.format(
            account.account_name))

        # region Update zones
        existing_zones = DNSZone.get_all(account)
        for data in zones:
            if data['zone_id'] in existing_zones:
                zone = DNSZone.get(data['zone_id'])
                if zone.update(data):
                    self.log.debug('Change detected for DNS zone {}/{}'.format(
                        account.account_name, zone.name))
                    db.session.add(zone.resource)
            else:
                DNSZone.create(data['zone_id'],
                               account_id=account.account_id,
                               properties={
                                   k: v
                                   for k, v in data.items()
                                   if k not in ('records', 'zone_id', 'tags')
                               },
                               tags=data['tags'])

                self.log.debug('Added DNS zone {}/{}'.format(
                    account.account_name, data['name']))

        db.session.commit()

        zk = set(x['zone_id'] for x in zones)
        ezk = set(existing_zones.keys())

        for resource_id in ezk - zk:
            zone = existing_zones[resource_id]

            # Delete all the records for the zone
            for record in zone.records:
                db.session.delete(record.resource)

            db.session.delete(zone.resource)
            self.log.debug('Deleted DNS zone {}/{}'.format(
                account.account_name, zone.name.value))
        db.session.commit()
        # endregion

        # region Update resource records
        for zone in zones:
            try:
                existing_zone = DNSZone.get(zone['zone_id'])
                existing_records = {
                    rec.id: rec
                    for rec in existing_zone.records
                }

                for data in zone['records']:
                    if data['id'] in existing_records:
                        record = existing_records[data['id']]
                        if record.update(data):
                            self.log.debug(
                                'Changed detected for DNSRecord {}/{}/{}'.
                                format(account.account_name, zone.name,
                                       data['name']))
                            db.session.add(record.resource)
                    else:
                        record = DNSRecord.create(
                            data['id'],
                            account_id=account.account_id,
                            properties={
                                k: v
                                for k, v in data.items()
                                if k not in ('records', 'zone_id')
                            },
                            tags={})
                        self.log.debug('Added new DNSRecord {}/{}/{}'.format(
                            account.account_name, zone['name'], data['name']))
                        existing_zone.add_record(record)
                db.session.commit()

                rk = set(x['id'] for x in zone['records'])
                erk = set(existing_records.keys())

                for resource_id in erk - rk:
                    record = existing_records[resource_id]
                    db.session.delete(record.resource)
                    self.log.debug('Deleted DNSRecord {}/{}/{}'.format(
                        account.account_name, zone['zone_id'], record.name))
                db.session.commit()
            except:
                self.log.exception(
                    'Error while attempting to update records for {}/{}'.
                    format(
                        account.account_name,
                        zone['zone_id'],
                    ))
                db.session.rollback()
        # endregion

    @retry
    def get_axfr_records(self, server, domains):
        """Return a `list` of `dict`s containing the zones and their records, obtained from the DNS server

        Returns:
            :obj:`list` of `dict`
        """
        zones = []
        for zoneName in domains:
            try:
                zone = {
                    'zone_id': get_resource_id('axfrz', zoneName),
                    'name': zoneName,
                    'source': 'AXFR',
                    'comment': None,
                    'tags': {},
                    'records': []
                }

                z = dns_zone.from_xfr(query.xfr(server, zoneName))
                rdata_fields = ('name', 'ttl', 'rdata')
                for rr in [
                        dict(zip(rdata_fields, x)) for x in z.iterate_rdatas()
                ]:
                    record_name = rr['name'].derelativize(z.origin).to_text()
                    zone['records'].append({
                        'id':
                        get_resource_id(
                            'axfrr', record_name,
                            ['{}={}'.format(k, str(v))
                             for k, v in rr.items()]),
                        'zone_id':
                        zone['zone_id'],
                        'name':
                        record_name,
                        'value':
                        sorted([rr['rdata'].to_text()]),
                        'type':
                        type_to_text(rr['rdata'].rdtype)
                    })

                if len(zone['records']) > 0:
                    zones.append(zone)

            except Exception as ex:
                self.log.exception(
                    'Failed fetching DNS zone information for {}: {}'.format(
                        zoneName, ex))
                raise

        return zones

    def get_cloudflare_records(self, *, account):
        """Return a `list` of `dict`s containing the zones and their records, obtained from the CloudFlare API

        Returns:
            account (:obj:`CloudFlareAccount`): A CloudFlare Account object
            :obj:`list` of `dict`
        """
        zones = []

        for zobj in self.__cloudflare_list_zones(account=account):
            try:
                self.log.debug('Processing DNS zone CloudFlare/{}'.format(
                    zobj['name']))
                zone = {
                    'zone_id': get_resource_id('cfz', zobj['name']),
                    'name': zobj['name'],
                    'source': 'CloudFlare',
                    'comment': None,
                    'tags': {},
                    'records': []
                }

                for record in self.__cloudflare_list_zone_records(
                        account=account, zoneID=zobj['id']):
                    zone['records'].append({
                        'id':
                        get_resource_id(
                            'cfr', zobj['id'],
                            ['{}={}'.format(k, v) for k, v in record.items()]),
                        'zone_id':
                        zone['zone_id'],
                        'name':
                        record['name'],
                        'value':
                        record['value'],
                        'type':
                        record['type']
                    })

                if len(zone['records']) > 0:
                    zones.append(zone)
            except CloudFlareError:
                self.log.exception(
                    'Failed getting records for CloudFlare zone {}'.format(
                        zobj['name']))

        return zones

    # region Helper functions for CloudFlare
    def __cloudflare_request(self, *, account, path, args=None):
        """Helper function to interact with the CloudFlare API.

        Args:
            account (:obj:`CloudFlareAccount`): CloudFlare Account object
            path (`str`): URL endpoint to communicate with
            args (:obj:`dict` of `str`: `str`): A dictionary of arguments for the endpoint to consume

        Returns:
            `dict`
        """
        if not args:
            args = {}

        if not self.cloudflare_initialized[account.account_id]:
            self.cloudflare_session[account.account_id] = requests.Session()
            self.cloudflare_session[account.account_id].headers.update({
                'X-Auth-Email':
                account.email,
                'X-Auth-Key':
                account.api_key,
                'Content-Type':
                'application/json'
            })
            self.cloudflare_initialized[account.account_id] = True

        if 'per_page' not in args:
            args['per_page'] = 100

        response = self.cloudflare_session[account.account_id].get(
            account.endpoint + path, params=args)
        if response.status_code != 200:
            raise CloudFlareError('Request failed: {}'.format(response.text))

        return response.json()

    def __cloudflare_list_zones(self, *, account, **kwargs):
        """Helper function to list all zones registered in the CloudFlare system. Returns a `list` of the zones

        Args:
            account (:obj:`CloudFlareAccount`): A CloudFlare Account object
            **kwargs (`dict`): Extra arguments to pass to the API endpoint

        Returns:
            `list` of `dict`
        """
        done = False
        zones = []
        page = 1

        while not done:
            kwargs['page'] = page
            response = self.__cloudflare_request(account=account,
                                                 path='/zones',
                                                 args=kwargs)
            info = response['result_info']

            if 'total_pages' not in info or page == info['total_pages']:
                done = True
            else:
                page += 1

            zones += response['result']

        return zones

    def __cloudflare_list_zone_records(self, *, account, zoneID, **kwargs):
        """Helper function to list all records on a CloudFlare DNS Zone. Returns a `dict` containing the records and
        their information.

        Args:
            account (:obj:`CloudFlareAccount`): A CloudFlare Account object
            zoneID (`int`): Internal CloudFlare ID of the DNS zone
            **kwargs (`dict`): Additional arguments to be consumed by the API endpoint

        Returns:
            :obj:`dict` of `str`: `dict`
        """
        done = False
        records = {}
        page = 1

        while not done:
            kwargs['page'] = page
            response = self.__cloudflare_request(
                account=account,
                path='/zones/{}/dns_records'.format(zoneID),
                args=kwargs)
            info = response['result_info']

            # Check if we have received all records, and if not iterate over the result set
            if 'total_pages' not in info or page >= info['total_pages']:
                done = True
            else:
                page += 1

            for record in response['result']:
                if record['name'] in records:
                    records[record['name']]['value'] = sorted(
                        records[record['name']]['value'] + [record['content']])
                else:
                    records[record['name']] = {
                        'name': record['name'],
                        'value': sorted([record['content']]),
                        'type': record['type']
                    }

        return list(records.values())
Exemple #9
0
class DomainHijackAuditor(BaseAuditor):
    """Domain Hijacking Auditor

    Checks DNS resource records for any pointers to non-existing assets in AWS (S3 buckets, Elastic Beanstalks, etc).
    """

    name = 'Domain Hijacking'
    ns = NS_AUDITOR_DOMAIN_HIJACKING
    interval = dbconfig.get('interval', ns, 30)
    options = (
        ConfigOption('enabled', False, 'bool', 'Enable the Domain Hijacking auditor'),
        ConfigOption('interval', 30, 'int', 'Run frequency in minutes'),
        ConfigOption('email_recipients', ['*****@*****.**'], 'array', 'List of emails to receive alerts'),
        ConfigOption('hijack_subject', 'Potential domain hijack detected', 'string',
                     'Email subject for domain hijack notifications'),
        ConfigOption('alert_frequency', 24, 'int', 'How frequent in hours, to alert'),
    )

    def __init__(self):
        super().__init__()

        self.recipients = dbconfig.get('email_recipients', self.ns)
        self.subject = dbconfig.get('hijack_subject', self.ns, 'Potential domain hijack detected')
        self.alert_frequency = dbconfig.get('alert_frequency', self.ns, 24)

    def run(self, *args, **kwargs):
        """Update the cache of all DNS entries and perform checks

        Args:
            *args: Optional list of arguments
            **kwargs: Optional list of keyword arguments

        Returns:
            None
        """
        try:
            zones = list(DNSZone.get_all().values())
            buckets = {k.lower(): v for k, v in S3Bucket.get_all().items()}
            dists = list(CloudFrontDist.get_all().values())
            ec2_public_ips = [x.public_ip for x in EC2Instance.get_all().values() if x.public_ip]
            beanstalks = {x.cname.lower(): x for x in BeanStalk.get_all().values()}

            existing_issues = DomainHijackIssue.get_all()
            issues = []

            # List of different types of domain audits
            auditors = [
                ElasticBeanstalkAudit(beanstalks),
                S3Audit(buckets),
                S3WithoutEndpointAudit(buckets),
                EC2PublicDns(ec2_public_ips),
            ]

            # region Build list of active issues
            for zone in zones:
                for record in zone.records:
                    for auditor in auditors:
                        if auditor.match(record):
                            issues.extend(auditor.audit(record, zone))

            for dist in dists:
                for org in dist.origins:
                    if org['type'] == 's3':
                        bucket = self.return_resource_name(org['source'], 's3')

                        if bucket not in buckets:
                            key = '{} ({})'.format(bucket, dist.type)
                            issues.append({
                                'key': key,
                                'value': 'S3Bucket {} doesnt exist on any known account. Referenced by {} on {}'.format(
                                    bucket,
                                    dist.domain_name,
                                    dist.account,
                                )
                            })
            # endregion

            # region Process new, old, fixed issue lists
            old_issues = {}
            new_issues = {}
            fixed_issues = []

            for data in issues:
                issue_id = get_resource_id('dhi', ['{}={}'.format(k, v) for k, v in data.items()])

                if issue_id in existing_issues:
                    issue = existing_issues[issue_id]

                    if issue.update({'state': 'EXISTING', 'end': None}):
                        db.session.add(issue.issue)

                    old_issues[issue_id] = issue

                else:
                    properties = {
                        'issue_hash': issue_id,
                        'state': 'NEW',
                        'start': datetime.now(),
                        'end': None,
                        'source': data['key'],
                        'description': data['value']
                    }
                    new_issues[issue_id] = DomainHijackIssue.create(issue_id, properties=properties)
            db.session.commit()

            for issue in list(existing_issues.values()):
                if issue.id not in new_issues and issue.id not in old_issues:
                    fixed_issues.append(issue.to_json())
                    db.session.delete(issue.issue)
            # endregion

            # Only alert if its been more than a day since the last alert
            alert_cutoff = datetime.now() - timedelta(hours=self.alert_frequency)
            old_alerts = []
            for issue_id, issue in old_issues.items():
                if issue.last_alert and issue.last_alert < alert_cutoff:
                    if issue.update({'last_alert': datetime.now()}):
                        db.session.add(issue.issue)

                    old_alerts.append(issue)

            db.session.commit()

            self.notify(
                [x.to_json() for x in new_issues.values()],
                [x.to_json() for x in old_alerts],
                fixed_issues
            )
        finally:
            db.session.rollback()

    def notify(self, new_issues, existing_issues, fixed_issues):
        """Send notifications (email, slack, etc.) for any issues that are currently open or has just been closed

        Args:
            new_issues (`list` of :obj:`DomainHijackIssue`): List of newly discovered issues
            existing_issues (`list` of :obj:`DomainHijackIssue`): List of existing open issues
            fixed_issues (`list` of `dict`): List of fixed issues

        Returns:
            None
        """
        if len(new_issues + existing_issues + fixed_issues) > 0:
            maxlen = max(len(x['properties']['source']) for x in (new_issues + existing_issues + fixed_issues)) + 2
            text_tmpl = get_template('domain_hijacking.txt')
            html_tmpl = get_template('domain_hijacking.html')
            issues_text = text_tmpl.render(
                new_issues=new_issues,
                existing_issues=existing_issues,
                fixed_issues=fixed_issues,
                maxlen=maxlen
            )
            issues_html = html_tmpl.render(
                new_issues=new_issues,
                existing_issues=existing_issues,
                fixed_issues=fixed_issues,
                maxlen=maxlen
            )

            try:
                send_notification(
                    subsystem=self.name,
                    recipients=[NotificationContact('email', addr) for addr in self.recipients],
                    subject=self.subject,
                    body_html=issues_html,
                    body_text=issues_text
                )
            except Exception as ex:
                self.log.exception('Failed sending notification email: {}'.format(ex))

    def return_resource_name(self, record, resource_type):
        """ Removes the trailing AWS domain from a DNS record
            to return the resource name

            e.g bucketname.s3.amazonaws.com will return bucketname

        Args:
            record (str): DNS record
            resource_type: AWS Resource type (i.e. S3 Bucket, Elastic Beanstalk, etc..)

        """
        try:
            if resource_type == 's3':
                regex = re.compile('.*(\.(?:s3-|s3){1}(?:.*)?\.amazonaws\.com)')
                bucket_name = record.replace(regex.match(record).group(1), '')
                return bucket_name

        except Exception as e:
            self.log.error('Unable to parse DNS record {} for resource type {}/{}'.format(record, resource_type, e))
            return record
Exemple #10
0
class SQSScheduler(BaseScheduler):
    name = 'SQS Scheduler'
    ns = NS_SCHEDULER_SQS
    options = (
        ConfigOption('queue_region', 'us-west-2', 'string', 'Region of the SQS Queues'),
        ConfigOption('job_queue_url', '', 'string', 'URL of the SQS Queue for pending jobs'),
        ConfigOption('status_queue_url', '', 'string', 'URL of the SQS Queue for worker reports'),
        ConfigOption('job_delay', 2, 'float', 'Time between each scheduled job, in seconds. Can be used to '
                     'avoid spiky load during execution of tasks'),
    )

    def __init__(self):
        """Initialize the SQSScheduler, setting up the process pools, scheduler and connecting to the required
        SQS Queues"""
        super().__init__()

        self.pool = ProcessPoolExecutor(1)
        self.scheduler = APScheduler(
            threadpool=self.pool,
            job_defaults={
                'coalesce': True,
                'misfire_grace_time': 30
            }
        )

        session = get_local_aws_session()
        sqs = session.resource('sqs', self.dbconfig.get('queue_region', self.ns))

        self.job_queue = sqs.Queue(self.dbconfig.get('job_queue_url', self.ns))
        self.status_queue = sqs.Queue(self.dbconfig.get('status_queue_url', self.ns))

    def execute_scheduler(self):
        """Main entry point for the scheduler. This method will start two scheduled jobs, `schedule_jobs` which takes
         care of scheduling the actual SQS messaging and `process_status_queue` which will track the current status
         of the jobs as workers are executing them

        Returns:
            `None`
        """
        try:
            # Schedule periodic scheduling of jobs
            self.scheduler.add_job(
                self.schedule_jobs,
                trigger='interval',
                name='schedule_jobs',
                minutes=15,
                start_date=datetime.now() + timedelta(seconds=1)
            )

            self.scheduler.add_job(
                self.process_status_queue,
                trigger='interval',
                name='process_status_queue',
                seconds=30,
                start_date=datetime.now() + timedelta(seconds=5),
                max_instances=1
            )

            self.scheduler.start()

        except KeyboardInterrupt:
            self.scheduler.shutdown()

    def list_current_jobs(self):
        """Return a list of the currently scheduled jobs in APScheduler

        Returns:
            `dict` of `str`: :obj:`apscheduler/job:Job`
        """
        jobs = {}
        for job in self.scheduler.get_jobs():
            if job.name not in ('schedule_jobs', 'process_status_queue'):
                jobs[job.name] = job

        return jobs

    def schedule_jobs(self):
        """Schedule or remove jobs as needed.

        Checks to see if there are any jobs that needs to be scheduled, after refreshing the database configuration
        as well as the list of collectors and auditors.

        Returns:
            `None`
        """
        self.dbconfig.reload_data()
        self.collectors = {}
        self.auditors = []
        self.load_plugins()

        _, accounts = BaseAccount.search(include_disabled=False)
        current_jobs = self.list_current_jobs()
        new_jobs = []
        batch_id = str(uuid4())

        batch = SchedulerBatch()
        batch.batch_id = batch_id
        batch.status = SchedulerStatus.PENDING
        db.session.add(batch)
        db.session.commit()

        start = datetime.now() + timedelta(seconds=1)
        job_delay = dbconfig.get('job_delay', self.ns, 0.5)

        # region Global collectors (non-aws)
        if CollectorType.GLOBAL in self.collectors:
            for worker in self.collectors[CollectorType.GLOBAL]:
                job_name = get_hash(worker)

                if job_name in current_jobs:
                    continue

                self.scheduler.add_job(
                    self.send_worker_queue_message,
                    trigger='interval',
                    name=job_name,
                    minutes=worker.interval,
                    start_date=start,
                    kwargs={
                        'batch_id': batch_id,
                        'job_name': job_name,
                        'entry_point': worker.entry_point,
                        'worker_args': {}
                    }
                )
                start += timedelta(seconds=job_delay)
        # endregion

        # region AWS collectors
        aws_accounts = list(filter(lambda x: x.account_type == AWSAccount.account_type, accounts))
        if CollectorType.AWS_ACCOUNT in self.collectors:
            for worker in self.collectors[CollectorType.AWS_ACCOUNT]:
                for account in aws_accounts:
                    job_name = get_hash((account.account_name, worker))
                    if job_name in current_jobs:
                        continue

                    new_jobs.append(job_name)

                    self.scheduler.add_job(
                        self.send_worker_queue_message,
                        trigger='interval',
                        name=job_name,
                        minutes=worker.interval,
                        start_date=start,
                        kwargs={
                            'batch_id': batch_id,
                            'job_name': job_name,
                            'entry_point': worker.entry_point,
                            'worker_args': {
                                'account': account.account_name
                            }
                        }
                    )
                    start += timedelta(seconds=job_delay)

        if CollectorType.AWS_REGION in self.collectors:
            for worker in self.collectors[CollectorType.AWS_REGION]:
                for region in AWS_REGIONS:
                    for account in aws_accounts:
                        job_name = get_hash((account.account_name, region, worker))

                        if job_name in current_jobs:
                            continue

                        new_jobs.append(job_name)

                        self.scheduler.add_job(
                            self.send_worker_queue_message,
                            trigger='interval',
                            name=job_name,
                            minutes=worker.interval,
                            start_date=start,
                            kwargs={
                                'batch_id': batch_id,
                                'job_name': job_name,
                                'entry_point': worker.entry_point,
                                'worker_args': {
                                    'account': account.account_name,
                                    'region': region
                                }
                            }
                        )
                        start += timedelta(seconds=job_delay)
        # endregion

        # region Auditors
        if app_config.log_level == 'DEBUG':
            audit_start = start + timedelta(seconds=5)
        else:
            audit_start = start + timedelta(minutes=5)

        for worker in self.auditors:
            job_name = get_hash((worker,))
            if job_name in current_jobs:
                continue

            new_jobs.append(job_name)

            self.scheduler.add_job(
                self.send_worker_queue_message,
                trigger='interval',
                name=job_name,
                minutes=worker.interval,
                start_date=audit_start,
                kwargs={
                    'batch_id': batch_id,
                    'job_name': job_name,
                    'entry_point': worker.entry_point,
                    'worker_args': {}
                }
            )
            audit_start += timedelta(seconds=job_delay)
        # endregion

    def send_worker_queue_message(self, *, batch_id, job_name, entry_point, worker_args, retry_count=0):
        """Send a message to the `worker_queue` for a worker to execute the requests job

        Args:
            batch_id (`str`): Unique ID of the batch the job belongs to
            job_name (`str`): Non-unique ID of the job. This is used to ensure that the same job is only scheduled
            a single time per batch
            entry_point (`dict`): A dictionary providing the entry point information for the worker to load the class
            worker_args (`dict`): A dictionary with the arguments required by the worker class (if any, can be an
            empty dictionary)
            retry_count (`int`): The number of times this one job has been attempted to be executed. If a job fails to
            execute after 3 retries it will be marked as failed

        Returns:
            `None`
        """
        try:
            job_id = str(uuid4())
            self.job_queue.send_message(
                MessageBody=json.dumps({
                    'batch_id': batch_id,
                    'job_id': job_id,
                    'job_name': job_name,
                    'entry_point': entry_point,
                    'worker_args': worker_args,
                }),
                MessageDeduplicationId=job_id,
                MessageGroupId=batch_id,
                MessageAttributes={
                    'RetryCount': {
                        'StringValue': str(retry_count),
                        'DataType': 'Number'

                    }
                }
            )

            if retry_count == 0:
                job = SchedulerJob()
                job.job_id = job_id
                job.batch_id = batch_id
                job.status = SchedulerStatus.PENDING
                job.data = worker_args

                db.session.add(job)
                db.session.commit()
        except:
            self.log.exception('Error when processing worker task')

    def execute_worker(self):
        """Retrieve a message from the `worker_queue` and process the request.

        This function will read a single message from the `worker_queue` and load the specified `EntryPoint`
        and execute the worker with the provided arguments. Upon completion (failure or otherwise) a message is sent
        to the `status_queue` information the scheduler about the return status (success/failure) of the worker

        Returns:
            `None`
        """
        try:
            try:
                messages = self.job_queue.receive_messages(
                    MaxNumberOfMessages=1,
                    MessageAttributeNames=('RetryCount',)
                )

            except ClientError:
                self.log.exception('Failed fetching messages from SQS queue')
                return

            if not messages:
                self.log.debug('No pending jobs')
                return

            for message in messages:
                try:
                    retry_count = int(message.message_attributes['RetryCount']['StringValue'])

                    data = json.loads(message.body)
                    try:
                        # SQS FIFO queues will not allow another thread to get any new messages until the messages
                        # in-flight are returned to the queue or deleted, so we remove the message from the queue as
                        # soon as we've loaded the data
                        self.send_status_message(data['job_id'], SchedulerStatus.STARTED)
                        message.delete()

                        cls = self.get_class_from_ep(data['entry_point'])
                        worker = cls(**data['worker_args'])
                        if hasattr(worker, 'type'):
                            if worker.type == CollectorType.GLOBAL:
                                self.log.info('RUN_INFO: {} starting at {}, next run will be at approximately {}'.format(data['entry_point']['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), (datetime.now() + timedelta(minutes=worker.interval)).strftime("%Y-%m-%d %H:%M:%S")))
                            elif worker.type == CollectorType.AWS_REGION:
                                self.log.info('RUN_INFO: {} starting at {} for account {} / region {}, next run will be at approximately {}'.format(data['entry_point']['module_name'],	datetime.now().strftime("%Y-%m-%d %H:%M:%S"), data['worker_args']['account'], data['worker_args']['region'], (datetime.now() + timedelta(minutes=worker.interval)).strftime("%Y-%m-%d %H:%M:%S")))
                            elif worker.type == CollectorType.AWS_ACCOUNT:
                                self.log.info('RUN_INFO: {} starting at {} for account {} next run will be at approximately {}'.format(data['entry_point']['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), data['worker_args']['account'], (datetime.now() + timedelta(minutes=worker.interval)).strftime("%Y-%m-%d %H:%M:%S")))
                        else:
                            self.log.info('RUN_INFO: {} starting at {} next run will be at approximately {}'.format(data['entry_point']['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), (datetime.now() + timedelta(minutes=worker.interval)).strftime("%Y-%m-%d %H:%M:%S")))
                        worker.run()

                        self.send_status_message(data['job_id'], SchedulerStatus.COMPLETED)
                    except InquisitorError:
                        # If the job failed for some reason, reschedule it unless it has already been retried 3 times
                        if retry_count >= 3:
                            self.send_status_message(data['job_id'], SchedulerStatus.FAILED)
                        else:
                            self.send_worker_queue_message(
                                batch_id=data['batch_id'],
                                job_name=data['job_name'],
                                entry_point=data['entry_point'],
                                worker_args=data['worker_args'],
                                retry_count=retry_count + 1
                            )
                except:
                    self.log.exception('Failed processing scheduler job: {}'.format(message.body))

        except KeyboardInterrupt:
            self.log.info('Shutting down worker thread')


    @retry
    def send_status_message(self, object_id, status):
        """Send a message to the `status_queue` to update a job's status.

        Returns `True` if the message was sent, else `False`

        Args:
            object_id (`str`): ID of the job that was executed
            status (:obj:`SchedulerStatus`): Status of the job

        Returns:
            `bool`
        """
        try:
            body = json.dumps({
                'id': object_id,
                'status': status
            })

            self.status_queue.send_message(
                MessageBody=body,
                MessageGroupId='job_status',
                MessageDeduplicationId=get_hash((object_id, status))
            )
            return True
        except Exception as ex:
            print(ex)
            return False

    @retry
    def process_status_queue(self):
        """Process all messages in the `status_queue` and check for any batches that needs to change status

        Returns:
            `None`
        """
        self.log.debug('Start processing status queue')
        while True:
            messages = self.status_queue.receive_messages(MaxNumberOfMessages=10)

            if not messages:
                break

            for message in messages:
                data = json.loads(message.body)
                job = SchedulerJob.get(data['id'])
                try:
                    if job and job.update_status(data['status']):
                        db.session.commit()
                except SchedulerError as ex:
                    if hasattr(ex, 'message') and ex.message == 'Attempting to update already completed job':
                        pass

                message.delete()

        # Close any batch that is now complete
        open_batches = db.SchedulerBatch.find(SchedulerBatch.status < SchedulerStatus.COMPLETED)
        for batch in open_batches:
            open_jobs = list(filter(lambda x: x.status < SchedulerStatus.COMPLETED, batch.jobs))
            if not open_jobs:
                open_batches.remove(batch)
                batch.update_status(SchedulerStatus.COMPLETED)
                self.log.debug('Closed completed batch {}'.format(batch.batch_id))
            else:
                started_jobs = list(filter(lambda x: x.status > SchedulerStatus.PENDING, open_jobs))
                if batch.status == SchedulerStatus.PENDING and len(started_jobs) > 0:
                    batch.update_status(SchedulerStatus.STARTED)
                    self.log.debug('Started batch manually {}'.format(batch.batch_id))

        # Check for stale batches / jobs
        for batch in open_batches:
            if batch.started < datetime.now() - timedelta(hours=2):
                self.log.warning('Closing a stale scheduler batch: {}'.format(batch.batch_id))
                for job in batch.jobs:
                    if job.status < SchedulerStatus.COMPLETED:
                        job.update_status(SchedulerStatus.ABORTED)
                batch.update_status(SchedulerStatus.ABORTED)
        db.session.commit()
Exemple #11
0
class SlackNotifier(BaseNotifier):
    name = 'Slack Notifier'
    ns = NS_SLACK
    notifier_type = 'slack'
    validation = r'^(#[a-zA-Z0-9\-_]+|{})$'.format(RGX_EMAIL_VALIDATION_PATTERN)
    options = (
        ConfigOption('enabled', False, 'bool', 'Enable the Slack notifier plugin'),
        ConfigOption('api_key', '', 'string', 'API token for the slack notifications'),
        ConfigOption('bot_name', 'Inquisitor', 'string', 'Name of the bot in Slack'),
        ConfigOption('bot_color', '#607d8b', 'string', 'Hex formatted color code for notifications'),
    )

    def __init__(self, api_key=None):
        super().__init__()

        if not self.enabled:
            raise SlackError('Slack messaging is disabled')

        self.slack_client = SlackClient(api_key or dbconfig.get('api_key', self.ns))
        self.bot_name = dbconfig.get('bot_name', self.ns, 'Inquisitor')
        self.color = dbconfig.get('bot_color', self.ns, '#607d8b')

        if not self._check_credentials():
            raise SlackError('Failed authenticating to the slack api. Please check the API is valid')

    def _check_credentials(self):
        try:
            response = self.slack_client.api_call('auth.test')
            return response['ok']
        except Exception:
            return False

    def _get_user_id_by_email(self, email):
        try:
            response = self.slack_client.api_call('users.list')

            if not response['ok']:
                raise SlackError('Failed to list Slack users: {}'.format(response['error']))

            user = list(filter(lambda x: x['profile'].get('email') == email, response['members']))
            if user:
                return user.pop()['id']
            else:
                SlackError('Failed to get user from Slack!')

        except Exception as ex:
            raise SlackError(ex)

    def _get_channel_for_user(self, user_email):
        user_id = self._get_user_id_by_email(user_email)
        try:
            response = self.slack_client.api_call('im.open', user=user_id)

            if not response['ok']:
                raise SlackError('Failed to get channel for user: {}'.format(response['error']))

            return response['channel']['id']
        except Exception as ex:
            raise SlackError(ex)

    def _send_message(self, target_type, target, message, title):
        if target_type == 'user':
            channel = self._get_channel_for_user(target)
        else:
            channel = target

        result = self.slack_client.api_call(
            'chat.postMessage',
            channel=channel,
            attachments=[
                {
                    'fallback': message,
                    'color': self.color,
                    'title': title,
                    'text': message
                }
            ],
            username=self.bot_name
        )
        if not result.get('ok', False):
            raise SlackError('Failed to send message: {}'.format(result['error']))

    def notify(self, subsystem, recipient, subject, body_html, body_text):
        """You can send messages either to channels and private groups by using the following formats

        #channel-name
        @username-direct-message

        Args:
            subsystem (`str`): Name of the subsystem originating the notification
            recipient (`str`): Recipient
            subject (`str`): Subject / title of the notification, not used for this notifier
            body_html (`str)`: HTML formatted version of the message, not used for this notifier
            body_text (`str`): Text formatted version of the message

        Returns:
            `None`
        """
        if not re.match(self.validation, recipient, re.I):
            raise ValueError('Invalid recipient provided')

        if recipient.startswith('#'):
            target_type = 'channel'

        elif recipient.find('@') != -1:
            target_type = 'user'

        else:
            self.log.error('Unknown contact type for Slack: {}'.format(recipient))
            return

        try:
            self._send_message(
                target_type=target_type,
                target=recipient,
                message=body_text,
                title=subject
            )
        except SlackError as ex:
            self.log.error('Failed sending message to {}: {}'.format(recipient, ex))

    @staticmethod
    @deprecated('send_message has been deprecated, use cloud_inquisitor.utils.send_notifications instead')
    def send_message(contacts, message):
        """List of contacts the send the message to. You can send messages either to channels and private groups by
        using the following formats

        #channel-name
        @username-direct-message

        If the channel is the name of a private group / channel, you must first invite the bot to the channel to ensure
        it is allowed to send messages to the group.

        Returns true if the message was sent, else `False`

        Args:
            contacts (:obj:`list` of `str`,`str`): List of contacts
            message (str): Message to send

        Returns:
            `bool`
        """
        if type(contacts) == str:
            contacts = [contacts]

        recipients = list(set(contacts))

        send_notification(
            subsystem='UNKNOWN',
            recipients=[NotificationContact('slack', x) for x in recipients],
            subject=None,
            body_html=message,
            body_text=message
        )
Exemple #12
0
class CloudTrailAuditor(BaseAuditor):
    """CloudTrail auditor

    Ensures that CloudTrail is enabled and logging to a central location and that SNS/SQS notifications are enabled
    and being sent to the correct queues for the CloudTrail Logs application
    """

    name = 'CloudTrail'
    ns = NS_AUDITOR_CLOUDTRAIL
    interval = dbconfig.get('interval', ns, 60)
    options = (
        ConfigOption('enabled', False, 'bool',
                     'Enable the Cloudtrail auditor'),
        ConfigOption('interval', 60, 'int', 'Run frequency in minutes'),
        ConfigOption(
            'bucket_account', 'CHANGE ME', 'string',
            'Name of the account in which to create the S3 bucket where CloudTrail logs will be delivered. '
            'The account must exist in the accounts section of the tool'),
        ConfigOption('bucket_name', 'CHANGE ME', 'string',
                     'Name of the S3 bucket to send CloudTrail logs to'),
        ConfigOption('bucket_region', 'us-west-2', 'string',
                     'Region for the S3 bucket for CloudTrail logs'),
        ConfigOption('global_cloudtrail_region', 'us-west-2', 'string',
                     'Region where to enable the global Cloudtrail'),
        ConfigOption('sns_topic_name', 'CHANGE ME', 'string',
                     'Name of the SNS topic for CloudTrail log delivery'),
        ConfigOption(
            'sqs_queue_account', 'CHANGE ME', 'string',
            'Name of the account which owns the SQS queue for CloudTrail log delivery notifications. '
            'This account must exist in the accounts section of the tool'),
        ConfigOption('sqs_queue_name', 'SET ME', 'string',
                     'Name of the SQS queue'),
        ConfigOption('sqs_queue_region', 'us-west-2', 'string',
                     'Region for the SQS queue'),
        ConfigOption('trail_name', 'Cinq_Auditing', 'string',
                     'Name of the CloudTrail trail to create'),
    )

    def run(self, *args, **kwargs):
        """Entry point for the scheduler

        Args:
            *args: Optional arguments
            **kwargs: Optional keyword arguments

        Returns:
            None
        """
        accounts = list(AWSAccount.get_all(include_disabled=False).values())

        # S3 Bucket config
        s3_acl = get_template('cloudtrail_s3_bucket_policy.json')
        s3_bucket_name = self.dbconfig.get('bucket_name', self.ns)
        s3_bucket_region = self.dbconfig.get('bucket_region', self.ns,
                                             'us-west-2')
        s3_bucket_account = AWSAccount.get(
            self.dbconfig.get('bucket_account', self.ns))
        CloudTrail.create_s3_bucket(s3_bucket_name, s3_bucket_region,
                                    s3_bucket_account, s3_acl)

        self.validate_sqs_policy(accounts)

        for account in accounts:
            ct = CloudTrail(account, s3_bucket_name, s3_bucket_region,
                            self.log)
            ct.run()

    def validate_sqs_policy(self, accounts):
        """Given a list of accounts, ensures that the SQS policy allows all the accounts to write to the queue

        Args:
            accounts (`list` of :obj:`Account`): List of accounts

        Returns:
            `None`
        """
        sqs_queue_name = self.dbconfig.get('sqs_queue_name', self.ns)
        sqs_queue_region = self.dbconfig.get('sqs_queue_region', self.ns)
        sqs_account = AWSAccount.get(
            self.dbconfig.get('sqs_queue_account', self.ns))
        session = get_aws_session(sqs_account)

        sqs = session.client('sqs', region_name=sqs_queue_region)
        sqs_queue_url = sqs.get_queue_url(
            QueueName=sqs_queue_name,
            QueueOwnerAWSAccountId=sqs_account.account_number)
        sqs_attribs = sqs.get_queue_attributes(
            QueueUrl=sqs_queue_url['QueueUrl'], AttributeNames=['Policy'])

        policy = json.loads(sqs_attribs['Attributes']['Policy'])

        for account in accounts:
            arn = 'arn:aws:sns:*:{}:{}'.format(account.account_number,
                                               sqs_queue_name)
            if arn not in policy['Statement'][0]['Condition'][
                    'ForAnyValue:ArnEquals']['aws:SourceArn']:
                self.log.warning(
                    'SQS policy is missing condition for ARN {}'.format(arn))
                policy['Statement'][0]['Condition']['ForAnyValue:ArnEquals'][
                    'aws:SourceArn'].append(arn)

        sqs.set_queue_attributes(QueueUrl=sqs_queue_url['QueueUrl'],
                                 Attributes={'Policy': json.dumps(policy)})
Exemple #13
0
class EmailNotifier(BaseNotifier):
    name = 'Email Notifier'
    ns = NS_EMAIL
    enabled = dbconfig.get('enabled', ns, True)
    options = (
        ConfigOption('enabled', True, 'bool', 'Enable the Email notifier plugin'),
        ConfigOption('from_address', '*****@*****.**', 'string', 'Sender address for emails'),
        ConfigOption('method', 'ses', 'string', 'EMail sending method, either ses or smtp'),
        ConfigOption('from_arn', '', 'string',
            'If using cross-account SES, this is the "From ARN", otherwise leave blank'
        ),
        ConfigOption('return_path_arn', '', 'string',
            'If using cross-account SES, this is the "Return Path ARN", otherwise leave blank'
        ),
        ConfigOption('source_arn', '', 'string',
            'If using cross-account SES, this is the "Source ARN", otherwise leave blank'
        ),
        ConfigOption('ses_region', 'us-west-2', 'string', 'Which SES region to send emails from'),
        ConfigOption('smtp_server', 'localhost', 'string', 'Address of the SMTP server to use'),
        ConfigOption('smtp_port', 25, 'int', 'Port for the SMTP server'),
        ConfigOption('smtp_username', '', 'string',
            'Username for SMTP authentication. Leave blank for no authentication'
        ),
        ConfigOption('smtp_password', '', 'string',
            'Password for SMTP authentication. Leave blank for no authentication'
        ),
        ConfigOption('smtp_tls', False, 'bool', 'Use TLS for sending emails'),
    )
Exemple #14
0
class EBSAuditor(BaseAuditor):
    """Known issue: if this runs before collector, we don't have EBSVolume or EBSVolumeAttachment data."""
    name = 'EBS Auditor'
    ns = NS_AUDITOR_EBS
    interval = dbconfig.get('interval', ns, 1440)
    options = (
        ConfigOption('enabled', False, 'bool', 'Enable the EBS auditor'),
        ConfigOption('interval', 1440, 'int',
                     'How often the auditor runs, in minutes'),
        ConfigOption('renotify_delay_days', 14, 'int',
                     'Send another notifications n days after the last'),
        ConfigOption('email_subject', 'Unattached EBS Volumes', 'string',
                     'Subject of the notification emails'),
        ConfigOption(
            'ignore_tags', ['cinq:ignore'], 'array',
            'A list of tags that will cause the auditor to ignore the volume'))

    def __init__(self):
        super().__init__()
        self.subject = self.dbconfig.get('email_subject', self.ns)

    def run(self, *args, **kwargs):
        """Main execution point for the auditor

        Args:
            *args:
            **kwargs:

        Returns:
            `None`
        """
        self.log.debug('Starting EBSAuditor')
        data = self.update_data()

        notices = defaultdict(list)
        for account, issues in data.items():
            for issue in issues:
                for recipient in account.contacts:
                    notices[NotificationContact(
                        type=recipient['type'],
                        value=recipient['value'])].append(issue)

        self.notify(notices)

    def update_data(self):
        """Update the database with the current state and return a dict containing the new / updated and fixed
        issues respectively, keyed by the account object

        Returns:
            `dict`
        """
        existing_issues = EBSVolumeAuditIssue.get_all()

        volumes = self.get_unattached_volumes()
        new_issues = self.process_new_issues(volumes, existing_issues)
        fixed_issues = self.process_fixed_issues(volumes, existing_issues)

        # region Process the data to be returned
        output = defaultdict(list)
        for acct, data in new_issues.items():
            output[acct] += data
        # endregion

        # region Update the database with the changes pending
        for issues in new_issues.values():
            for issue in issues:
                db.session.add(issue.issue)

        for issue in fixed_issues:
            db.session.delete(issue.issue)

        db.session.commit()
        # endregion

        return output

    def get_unattached_volumes(self):
        """Build a list of all volumes missing tags and not ignored. Returns a `dict` keyed by the issue_id with the
        volume as the value

        Returns:
            :obj:`dict` of `str`: `EBSVolume`
        """
        volumes = {}
        ignored_tags = dbconfig.get('ignore_tags', self.ns)
        for volume in EBSVolume.get_all().values():
            issue_id = get_resource_id('evai', volume.id)

            if len(volume.attachments) == 0:
                if len(
                        list(
                            filter(
                                set(ignored_tags).__contains__,
                                [tag.key for tag in volume.tags]))):
                    continue

                volumes[issue_id] = volume

        return volumes

    def process_new_issues(self, volumes, existing_issues):
        """Takes a dict of existing volumes missing tags and a dict of existing issues, and finds any new or updated
        issues.

        Args:
            volumes (:obj:`dict` of `str`: `EBSVolume`): Dict of current volumes with issues
            existing_issues (:obj:`dict` of `str`: `EBSVolumeAuditIssue`): Current list of issues

        Returns:
            :obj:`dict` of `str`: `EBSVolumeAuditIssue`
        """
        new_issues = {}
        for issue_id, volume in volumes.items():
            state = EBSIssueState.DETECTED.value

            if issue_id in existing_issues:
                issue = existing_issues[issue_id]

                data = {
                    'state': state,
                    'notes': issue.notes,
                    'last_notice': issue.last_notice
                }
                if issue.update(data):
                    new_issues.setdefault(issue.volume.account,
                                          []).append(issue)
                    self.log.debug(
                        'Updated EBSVolumeAuditIssue {}'.format(issue_id))

            else:
                properties = {
                    'volume_id': volume.id,
                    'account_id': volume.account_id,
                    'location': volume.location,
                    'state': state,
                    'last_change': datetime.now(),
                    'last_notice': None,
                    'notes': []
                }

                issue = EBSVolumeAuditIssue.create(issue_id,
                                                   properties=properties)
                new_issues.setdefault(issue.volume.account, []).append(issue)

        return new_issues

    def process_fixed_issues(self, volumes, existing_issues):
        """Provided a list of volumes and existing issues, returns a list of fixed issues to be deleted

        Args:
            volumes (`dict`): A dictionary keyed on the issue id, with the :obj:`Volume` object as the value
            existing_issues (`dict`): A dictionary keyed on the issue id, with the :obj:`EBSVolumeAuditIssue` object as
            the value

        Returns:
            :obj:`list` of :obj:`EBSVolumeAuditIssue`
        """
        fixed_issues = []
        for issue_id, issue in list(existing_issues.items()):
            if issue_id not in volumes:
                fixed_issues.append(issue)

        return fixed_issues

    def notify(self, notices):
        """Send notifications to the users via. the provided methods

        Args:
            notices (:obj:`dict` of `str`: `dict`): List of the notifications to send

        Returns:
            `None`
        """
        issues_html = get_template('unattached_ebs_volume.html')
        issues_text = get_template('unattached_ebs_volume.txt')

        for recipient, issues in list(notices.items()):
            if issues:
                message_html = issues_html.render(issues=issues)
                message_text = issues_text.render(issues=issues)

                send_notification(subsystem=self.name,
                                  recipients=[recipient],
                                  subject=self.subject,
                                  body_html=message_html,
                                  body_text=message_text)
Exemple #15
0
class AWSRegionCollector(BaseCollector):
    name = 'EC2 Region Collector'
    ns = 'collector_ec2'
    type = CollectorType.AWS_REGION
    interval = dbconfig.get('interval', ns, 15)
    options = (
        ConfigOption('enabled', True, 'bool',
                     'Enable the AWS Region-based Collector'),
        ConfigOption('interval', 15, 'int', 'Run frequency, in minutes'),
        ConfigOption('max_instances', 1000, 'int',
                     'Maximum number of instances per API call'),
    )

    def __init__(self, account, region):
        super().__init__()

        if type(account) == str:
            account = AWSAccount.get(account)

        if not isinstance(account, AWSAccount):
            raise InquisitorError(
                'The AWS Collector only supports AWS Accounts, got {}'.format(
                    account.__class__.__name__))

        self.account = account
        self.region = region
        self.session = get_aws_session(self.account)

    def run(self, *args, **kwargs):
        try:
            self.update_instances()
            self.update_volumes()
            self.update_snapshots()
            self.update_amis()
            self.update_beanstalks()
            self.update_vpcs()
            self.update_elbs()
        except Exception as ex:
            self.log.exception(ex)
            raise
        finally:
            del self.session

    @retry
    def update_instances(self):
        """Update list of EC2 Instances for the account / region

        Returns:
            `None`
        """
        self.log.debug('Updating EC2Instances for {}/{}'.format(
            self.account.account_name, self.region))
        ec2 = self.session.resource('ec2', region_name=self.region)

        try:
            existing_instances = EC2Instance.get_all(self.account, self.region)
            instances = {}
            api_instances = {x.id: x for x in ec2.instances.all()}

            try:
                for instance_id, data in api_instances.items():
                    if data.instance_id in existing_instances:
                        instance = existing_instances[instance_id]

                        if data.state['Name'] not in ('terminated',
                                                      'shutting-down'):
                            instances[instance_id] = instance

                            # Add object to transaction if it changed
                            if instance.update(data):
                                self.log.debug(
                                    'Updating info for instance {} in {}/{}'.
                                    format(instance.resource.resource_id,
                                           self.account.account_name,
                                           self.region))
                                db.session.add(instance.resource)
                    else:
                        # New instance, if its not in state=terminated
                        if data.state['Name'] in ('terminated',
                                                  'shutting-down'):
                            continue

                        tags = {
                            tag['Key']: tag['Value']
                            for tag in data.tags or {}
                        }
                        properties = {
                            'launch_date':
                            to_utc_date(data.launch_time).isoformat(),
                            'state':
                            data.state['Name'],
                            'instance_type':
                            data.instance_type,
                            'public_ip':
                            getattr(data, 'public_ip_address', None),
                            'public_dns':
                            getattr(data, 'public_dns_address', None),
                            'created':
                            isoformat(datetime.now())
                        }

                        instance = EC2Instance.create(
                            data.instance_id,
                            account_id=self.account.account_id,
                            location=self.region,
                            properties=properties,
                            tags=tags)

                        instances[instance.resource.resource_id] = instance
                        self.log.debug('Added new EC2Instance {}/{}/{}'.format(
                            self.account.account_name, self.region,
                            instance.resource.resource_id))

                # Check for deleted instances
                ik = set(list(instances.keys()))
                eik = set(list(existing_instances.keys()))

                for instanceID in eik - ik:
                    db.session.delete(existing_instances[instanceID].resource)
                    self.log.debug('Deleted EC2Instance {}/{}/{}'.format(
                        self.account.account_name, self.region, instanceID))

                db.session.commit()
            except:
                db.session.rollback()
                raise
        finally:
            del ec2

    @retry
    def update_amis(self):
        """Update list of AMIs for the account / region

        Returns:
            `None`
        """
        self.log.debug('Updating AMIs for {}/{}'.format(
            self.account.account_name, self.region))
        ec2 = self.session.resource('ec2', region_name=self.region)

        try:
            existing_images = AMI.get_all(self.account, self.region)
            images = {x.id: x for x in ec2.images.filter(Owners=['self'])}

            for data in list(images.values()):
                if data.id in existing_images:
                    ami = existing_images[data.id]
                    if ami.update(data):
                        self.log.debug(
                            'Changed detected for AMI {}/{}/{}'.format(
                                self.account.account_name, self.region,
                                ami.resource.resource_id))
                else:
                    properties = {
                        'architecture':
                        data.architecture,
                        'creation_date':
                        parse_date(data.creation_date
                                   or '1970-01-01 00:00:00'),
                        'description':
                        data.description,
                        'name':
                        data.name,
                        'platform':
                        data.platform or 'Linux',
                        'state':
                        data.state,
                    }
                    tags = {
                        tag['Key']: tag['Value']
                        for tag in data.tags or {}
                    }

                    AMI.create(data.id,
                               account_id=self.account.account_id,
                               location=self.region,
                               properties=properties,
                               tags=tags)

                    self.log.debug('Added new AMI {}/{}/{}'.format(
                        self.account.account_name, self.region, data.id))
            db.session.commit()

            # Check for deleted instances
            ik = set(list(images.keys()))
            eik = set(list(existing_images.keys()))

            try:
                for image_id in eik - ik:
                    db.session.delete(existing_images[image_id].resource)
                    self.log.debug('Deleted AMI {}/{}/{}'.format(
                        self.account.account_name,
                        self.region,
                        image_id,
                    ))

                db.session.commit()
            except:
                db.session.rollback()
        finally:
            del ec2

    @retry
    def update_volumes(self):
        """Update list of EBS Volumes for the account / region

        Returns:
            `None`
        """
        self.log.debug('Updating EBSVolumes for {}/{}'.format(
            self.account.account_name, self.region))
        ec2 = self.session.resource('ec2', region_name=self.region)

        try:
            existing_volumes = EBSVolume.get_all(self.account, self.region)
            volumes = {x.id: x for x in ec2.volumes.all()}

            for data in list(volumes.values()):
                if data.id in existing_volumes:
                    vol = existing_volumes[data.id]
                    if vol.update(data):
                        self.log.debug(
                            'Changed detected for EBSVolume {}/{}/{}'.format(
                                self.account.account_name, self.region,
                                vol.resource.resource_id))

                else:
                    properties = {
                        'create_time':
                        data.create_time,
                        'encrypted':
                        data.encrypted,
                        'iops':
                        data.iops or 0,
                        'kms_key_id':
                        data.kms_key_id,
                        'size':
                        data.size,
                        'state':
                        data.state,
                        'snapshot_id':
                        data.snapshot_id,
                        'volume_type':
                        data.volume_type,
                        'attachments':
                        sorted([x['InstanceId'] for x in data.attachments])
                    }
                    tags = {t['Key']: t['Value'] for t in data.tags or {}}
                    vol = EBSVolume.create(data.id,
                                           account_id=self.account.account_id,
                                           location=self.region,
                                           properties=properties,
                                           tags=tags)

                    self.log.debug('Added new EBSVolume {}/{}/{}'.format(
                        self.account.account_name, self.region,
                        vol.resource.resource_id))
            db.session.commit()

            vk = set(list(volumes.keys()))
            evk = set(list(existing_volumes.keys()))
            try:
                for volumeID in evk - vk:
                    db.session.delete(existing_volumes[volumeID].resource)
                    self.log.debug('Deleted EBSVolume {}/{}/{}'.format(
                        volumeID, self.account.account_name, self.region))

                db.session.commit()
            except:
                self.log.exception('Failed removing deleted volumes')
                db.session.rollback()
        finally:
            del ec2

    @retry
    def update_snapshots(self):
        """Update list of EBS Snapshots for the account / region

        Returns:
            `None`
        """
        self.log.debug('Updating EBSSnapshots for {}/{}'.format(
            self.account.account_name, self.region))
        ec2 = self.session.resource('ec2', region_name=self.region)

        try:
            existing_snapshots = EBSSnapshot.get_all(self.account, self.region)
            snapshots = {
                x.id: x
                for x in ec2.snapshots.filter(
                    OwnerIds=[self.account.account_number])
            }

            for data in list(snapshots.values()):
                if data.id in existing_snapshots:
                    snapshot = existing_snapshots[data.id]
                    if snapshot.update(data):
                        self.log.debug(
                            'Change detected for EBSSnapshot {}/{}/{}'.format(
                                self.account.account_name, self.region,
                                snapshot.resource.resource_id))

                else:
                    properties = {
                        'create_time': data.start_time,
                        'encrypted': data.encrypted,
                        'kms_key_id': data.kms_key_id,
                        'state': data.state,
                        'state_message': data.state_message,
                        'volume_id': data.volume_id,
                        'volume_size': data.volume_size,
                    }
                    tags = {t['Key']: t['Value'] for t in data.tags or {}}

                    snapshot = EBSSnapshot.create(
                        data.id,
                        account_id=self.account.account_id,
                        location=self.region,
                        properties=properties,
                        tags=tags)

                    self.log.debug('Added new EBSSnapshot {}/{}/{}'.format(
                        self.account.account_name, self.region,
                        snapshot.resource.resource_id))

            db.session.commit()

            vk = set(list(snapshots.keys()))
            evk = set(list(existing_snapshots.keys()))
            try:
                for snapshotID in evk - vk:
                    db.session.delete(existing_snapshots[snapshotID].resource)
                    self.log.debug('Deleted EBSSnapshot {}/{}/{}'.format(
                        self.account.account_name, self.region, snapshotID))

                db.session.commit()
            except:
                self.log.exception('Failed removing deleted snapshots')
                db.session.rollback()
        finally:
            del ec2

    @retry
    def update_beanstalks(self):
        """Update list of Elastic BeanStalks for the account / region

        Returns:
            `None`
        """
        self.log.debug(
            'Updating ElasticBeanStalk environments for {}/{}'.format(
                self.account.account_name, self.region))
        ebclient = self.session.client('elasticbeanstalk',
                                       region_name=self.region)

        try:
            existing_beanstalks = BeanStalk.get_all(self.account, self.region)
            beanstalks = {}
            # region Fetch elastic beanstalks
            for env in ebclient.describe_environments()['Environments']:
                # Only get information for HTTP (non-worker) Beanstalks
                if env['Tier']['Type'] == 'Standard':
                    if 'CNAME' in env:
                        beanstalks[env['EnvironmentId']] = {
                            'id': env['EnvironmentId'],
                            'environment_name': env['EnvironmentName'],
                            'application_name': env['ApplicationName'],
                            'cname': env['CNAME']
                        }
                    else:
                        self.log.warning(
                            'Found a BeanStalk that does not have a CNAME: {} in {}/{}'
                            .format(env['EnvironmentName'], self.account,
                                    self.region))
                else:
                    self.log.debug(
                        'Skipping worker tier ElasticBeanstalk environment {}/{}/{}'
                        .format(self.account.account_name, self.region,
                                env['EnvironmentName']))
            # endregion

            try:
                for data in beanstalks.values():
                    if data['id'] in existing_beanstalks:
                        beanstalk = existing_beanstalks[data['id']]
                        if beanstalk.update(data):
                            self.log.debug(
                                'Change detected for ElasticBeanStalk {}/{}/{}'
                                .format(self.account.account_name, self.region,
                                        data['id']))
                    else:
                        bid = data.pop('id')
                        tags = {}
                        BeanStalk.create(bid,
                                         account_id=self.account.account_id,
                                         location=self.region,
                                         properties=data,
                                         tags=tags)

                        self.log.debug(
                            'Added new ElasticBeanStalk {}/{}/{}'.format(
                                self.account.account_name, self.region, bid))
                db.session.commit()

                bk = set(beanstalks.keys())
                ebk = set(existing_beanstalks.keys())

                for resource_id in ebk - bk:
                    db.session.delete(
                        existing_beanstalks[resource_id].resource)
                    self.log.debug('Deleted ElasticBeanStalk {}/{}/{}'.format(
                        self.account.account_name, self.region, resource_id))
                db.session.commit()
            except:
                db.session.rollback()

            return beanstalks
        finally:
            del ebclient

    @retry
    def update_vpcs(self):
        """Update list of VPCs for the account / region

        Returns:
            `None`
        """
        self.log.debug('Updating VPCs for {}/{}'.format(
            self.account.account_name, self.region))

        existing_vpcs = VPC.get_all(self.account, self.region)
        try:
            ec2 = self.session.resource('ec2', region_name=self.region)
            ec2_client = self.session.client('ec2', region_name=self.region)
            vpcs = {x.id: x for x in ec2.vpcs.all()}

            for data in vpcs.values():
                flow_logs = ec2_client.describe_flow_logs(
                    Filters=[{
                        'Name': 'resource-id',
                        'Values': [data.vpc_id]
                    }]).get('FlowLogs')

                tags = {t['Key']: t['Value'] for t in data.tags or {}}
                properties = {
                    'vpc_id':
                    data.vpc_id,
                    'cidr_v4':
                    data.cidr_block,
                    'is_default':
                    data.is_default,
                    'state':
                    data.state,
                    'vpc_flow_logs_status':
                    flow_logs[0]['FlowLogStatus']
                    if flow_logs else 'UNDEFINED',
                    'vpc_flow_logs_log_group':
                    flow_logs[0]['LogGroupName'] if flow_logs else 'UNDEFINED',
                    'tags':
                    tags
                }
                if data.id in existing_vpcs:
                    vpc = existing_vpcs[data.vpc_id]
                    if vpc.update(data, properties):
                        self.log.debug(
                            'Change detected for VPC {}/{}/{} '.format(
                                data.vpc_id, self.region, properties))
                else:
                    VPC.create(data.id,
                               account_id=self.account.account_id,
                               location=self.region,
                               properties=properties,
                               tags=tags)
            db.session.commit()

            # Removal of VPCs
            vk = set(vpcs.keys())
            evk = set(existing_vpcs.keys())

            for resource_id in evk - vk:
                db.session.delete(existing_vpcs[resource_id].resource)
                self.log.debug('Removed VPCs {}/{}/{}'.format(
                    self.account.account_name, self.region, resource_id))
            db.session.commit()

        except Exception:
            self.log.exception(
                'There was a problem during VPC collection for {}/{}'.format(
                    self.account.account_name, self.region))
            db.session.rollback()

    @retry
    def update_elbs(self):
        """Update list of ELBs for the account / region

        Returns:
            `None`
        """
        self.log.debug('Updating ELBs for {}/{}'.format(
            self.account.account_name, self.region))

        # ELBs known to CINQ
        elbs_from_db = ELB.get_all(self.account, self.region)
        try:

            # ELBs known to AWS
            elb_client = self.session.client('elb', region_name=self.region)
            load_balancer_instances = elb_client.describe_load_balancers(
            )['LoadBalancerDescriptions']
            elbs_from_api = {}
            for load_balancer in load_balancer_instances:
                key = '{}::{}'.format(self.region,
                                      load_balancer['LoadBalancerName'])
                elbs_from_api[key] = load_balancer

            # Process ELBs known to AWS
            for elb_identifier in elbs_from_api:
                data = elbs_from_api[elb_identifier]
                # ELB already in DB?
                if elb_identifier in elbs_from_db:
                    elb = elbs_from_db[elb_identifier]
                    if elb.update(data):
                        self.log.info(
                            'Updating info for ELB {} in {}/{}'.format(
                                elb.resource.resource_id,
                                self.account.account_name, self.region))
                        db.session.add(elb.resource)
                else:
                    # Not previously seen this ELB, so add it
                    if 'Tags' in data:
                        try:
                            tags = {
                                tag['Key']: tag['Value']
                                for tag in data['Tags']
                            }
                        except AttributeError:
                            tags = {}
                    else:
                        tags = {}

                    vpc_data = (data['VPCId'] if
                                ('VPCId' in data
                                 and data['VPCId']) else 'no vpc')

                    properties = {
                        'lb_name':
                        data['LoadBalancerName'],
                        'dns_name':
                        data['DNSName'],
                        'instances':
                        ' '.join([
                            instance['InstanceId']
                            for instance in data['Instances']
                        ]),
                        'num_instances':
                        len([
                            instance['InstanceId']
                            for instance in data['Instances']
                        ]),
                        'vpc_id':
                        vpc_data,
                        'state':
                        'not_reported'
                    }
                    if 'CanonicalHostedZoneName' in data:
                        properties['canonical_hosted_zone_name'] = data[
                            'CanonicalHostedZoneName']
                    else:
                        properties['canonical_hosted_zone_name'] = None

                    # LoadBalancerName doesn't have to be unique across all regions
                    # Use region::LoadBalancerName as resource_id
                    resource_id = '{}::{}'.format(self.region,
                                                  data['LoadBalancerName'])

                    # All done, create
                    elb = ELB.create(resource_id,
                                     account_id=self.account.account_id,
                                     location=self.region,
                                     properties=properties,
                                     tags=tags)

                    # elbs[elb.resource.resource_id] = elb
                    self.log.info('Added new ELB {}/{}/{}'.format(
                        self.account.account_name, self.region,
                        elb.resource.resource_id))

            # Delete no longer existing ELBs
            elb_keys_from_db = set(list(elbs_from_db.keys()))
            self.log.debug('elb_keys_from_db =  %s', elb_keys_from_db)
            elb_keys_from_api = set(list(elbs_from_api.keys()))
            self.log.debug('elb_keys_from_api = %s', elb_keys_from_api)

            for elb_identifier in elb_keys_from_db - elb_keys_from_api:
                db.session.delete(elbs_from_db[elb_identifier].resource)
                self.log.info('Deleted ELB {}/{}/{}'.format(
                    self.account.account_name, self.region, elb_identifier))
            db.session.commit()

        except:
            self.log.exception(
                'There was a problem during ELB collection for {}/{}'.format(
                    self.account.account_name, self.region))
            db.session.rollback()
Exemple #16
0
class AWSAccountCollector(BaseCollector):
    name = 'AWS Account Collector'
    ns = 'collector_ec2'
    type = CollectorType.AWS_ACCOUNT
    interval = dbconfig.get('interval', ns, 15)
    s3_collection_enabled = dbconfig.get('s3_bucket_collection', ns, True)
    cloudfront_collection_enabled = dbconfig.get('cloudfront_collection', ns, True)
    route53_collection_enabled = dbconfig.get('route53_collection', ns, True)

    options = (
        ConfigOption('s3_bucket_collection', True, 'bool', 'Enable S3 Bucket Collection'),
        ConfigOption('cloudfront_collection', True, 'bool', 'Enable Cloudfront DNS Collection'),
        ConfigOption('route53_collection', True, 'bool', 'Enable Route53 DNS Collection'),
    )

    def __init__(self, account):
        super().__init__()

        if type(account) == str:
            account = AWSAccount.get(account)

        if not isinstance(account, AWSAccount):
            raise InquisitorError('The AWS Collector only supports AWS Accounts, got {}'.format(
                account.__class__.__name__
            ))

        self.account = account
        self.session = get_aws_session(self.account)

    def run(self):
        try:
            if self.s3_collection_enabled:
                self.update_s3buckets()

            if self.cloudfront_collection_enabled:
                self.update_cloudfront()

            if self.route53_collection_enabled:
                self.update_route53()

        except Exception as ex:
            self.log.exception(ex)
            raise

        finally:
            del self.session

    @retry
    def update_s3buckets(self):
        """Update list of S3 Buckets for the account

        Returns:
            `None`
        """
        self.log.debug('Updating S3Buckets for {}'.format(self.account.account_name))
        s3 = self.session.resource('s3')
        s3c = self.session.client('s3')

        try:
            existing_buckets = S3Bucket.get_all(self.account)
            buckets = {bucket.name: bucket for bucket in s3.buckets.all()}
            for data in buckets.values():
                # This section ensures that we handle non-existent or non-accessible sub-resources
                try:
                    bucket_region = s3c.get_bucket_location(Bucket=data.name)['LocationConstraint']
                    if not bucket_region:
                        bucket_region = 'us-east-1'

                except ClientError as e:
                    self.log.info('Could not get bucket location..bucket possibly removed / {}'.format(e))
                    bucket_region = 'unavailable'

                try:
                    bucket_policy = data.Policy().policy

                except ClientError as e:
                    if e.response['Error']['Code'] == 'NoSuchBucketPolicy':
                        bucket_policy = None
                    else:
                        self.log.info('There was a problem collecting bucket policy for bucket {} on account {}, {}'
                                      .format(data.name, self.account, e.response))
                        bucket_policy = 'cinq cannot poll'

                try:
                    website_enabled = 'Enabled' if data.Website().index_document else 'Disabled'

                except ClientError as e:
                    if e.response['Error']['Code'] == 'NoSuchWebsiteConfiguration':
                        website_enabled = 'Disabled'
                    else:
                        self.log.info('There was a problem collecting website config for bucket {} on account {}'
                                      .format(data.name, self.account))
                        website_enabled = 'cinq cannot poll'

                try:
                    tags = {t['Key']: t['Value'] for t in data.Tagging().tag_set}

                except ClientError:
                    tags = {}

                try:
                    bucket_size = self._get_bucket_statistics(data.name, bucket_region, 'StandardStorage',
                                                              'BucketSizeBytes', 3)

                    bucket_obj_count = self._get_bucket_statistics(data.name, bucket_region, 'AllStorageTypes',
                                                                   'NumberOfObjects', 3)

                    metrics = {'size': bucket_size, 'object_count': bucket_obj_count}

                except Exception as e:
                    self.log.info('Could not retrieve bucket statistics / {}'.format(e))
                    metrics = {'found': False}

                properties = {
                    'bucket_policy': bucket_policy,
                    'creation_date': data.creation_date,
                    'location': bucket_region,
                    'website_enabled': website_enabled,
                    'metrics': metrics,
                    'tags': tags
                }

                if data.name in existing_buckets:
                    bucket = existing_buckets[data.name]
                    if bucket.update(data, properties):
                        self.log.debug('Change detected for S3Bucket {}/{}'.format(
                            self.account.account_name,
                            bucket.id
                        ))
                        bucket.save()
                else:
                    # If a bucket has no tags, a boto3 error is thrown. We treat this as an empty tag set

                    S3Bucket.create(
                        data.name,
                        account_id=self.account.account_id,
                        properties=properties,
                        location=bucket_region,
                        tags=tags
                    )
                    self.log.debug('Added new S3Bucket {}/{}'.format(
                        self.account.account_name,
                        data.name
                    ))
            db.session.commit()

            bk = set(list(buckets.keys()))
            ebk = set(list(existing_buckets.keys()))

            try:
                for resource_id in ebk - bk:
                    db.session.delete(existing_buckets[resource_id].resource)
                    self.log.debug('Deleted S3Bucket {}/{}'.format(
                        self.account.account_name,
                        resource_id
                    ))
                db.session.commit()

            except Exception as e:
                self.log.error(
                    'Could not update the current S3Bucket list for account {}/{}'.format(self.account.account_name, e))
                db.session.rollback()

        finally:
            del s3, s3c

    @retry
    def update_cloudfront(self):
        """Update list of CloudFront Distributions for the account

        Returns:
            `None`
        """
        self.log.debug('Updating CloudFront distributions for {}'.format(self.account.account_name))
        cfr = self.session.client('cloudfront')

        try:
            existing_dists = CloudFrontDist.get_all(self.account, None)
            dists = []

            # region Fetch information from API
            # region Web distributions
            done = False
            marker = None
            while not done:
                if marker:
                    response = cfr.list_distributions(Marker=marker)
                else:
                    response = cfr.list_distributions()

                dl = response['DistributionList']
                if dl['IsTruncated']:
                    marker = dl['NextMarker']
                else:
                    done = True

                if 'Items' in dl:
                    for dist in dl['Items']:
                        origins = []
                        for origin in dist['Origins']['Items']:
                            if 'S3OriginConfig' in origin:
                                origins.append(
                                    {
                                        'type': 's3',
                                        'source': origin['DomainName']
                                    }
                                )
                            elif 'CustomOriginConfig' in origin:
                                origins.append(
                                    {
                                        'type': 'custom-http',
                                        'source': origin['DomainName']
                                    }
                                )

                        data = {
                            'arn': dist['ARN'],
                            'name': dist['DomainName'],
                            'origins': origins,
                            'enabled': dist['Enabled'],
                            'type': 'web',
                            'tags': self.__get_distribution_tags(cfr, dist['ARN'])
                        }
                        dists.append(data)
            # endregion

            # region Streaming distributions
            done = False
            marker = None
            while not done:
                if marker:
                    response = cfr.list_streaming_distributions(Marker=marker)
                else:
                    response = cfr.list_streaming_distributions()

                dl = response['StreamingDistributionList']
                if dl['IsTruncated']:
                    marker = dl['NextMarker']
                else:
                    done = True

                if 'Items' in dl:
                    dists += [
                        {
                            'arn': x['ARN'],
                            'name': x['DomainName'],
                            'origins': [{'type': 's3', 'source': x['S3Origin']['DomainName']}],
                            'enabled': x['Enabled'],
                            'type': 'rtmp',
                            'tags': self.__get_distribution_tags(cfr, x['ARN'])
                        } for x in dl['Items']
                    ]
            # endregion
            # endregion

            for data in dists:
                if data['arn'] in existing_dists:
                    dist = existing_dists[data['arn']]
                    if dist.update(data):
                        self.log.debug('Updated CloudFrontDist {}/{}'.format(
                            self.account.account_name,
                            data['name']
                        ))
                        dist.save()

                else:
                    properties = {
                        'domain_name': data['name'],
                        'origins': data['origins'],
                        'enabled': data['enabled'],
                        'type': data['type']
                    }

                    CloudFrontDist.create(
                        data['arn'],
                        account_id=self.account.account_id,
                        properties=properties,
                        tags=data['tags']
                    )

                    self.log.debug('Added new CloudFrontDist {}/{}'.format(
                        self.account.account_name,
                        data['name']
                    ))
            db.session.commit()

            dk = set(x['arn'] for x in dists)
            edk = set(existing_dists.keys())

            try:
                for resource_id in edk - dk:
                    db.session.delete(existing_dists[resource_id].resource)
                    self.log.debug('Deleted CloudFrontDist {}/{}'.format(
                        resource_id,
                        self.account.account_name
                    ))
                db.session.commit()
            except:
                db.session.rollback()
        finally:
            del cfr

    @retry
    def update_route53(self):
        """Update list of Route53 DNS Zones and their records for the account

        Returns:
            `None`
        """
        self.log.debug('Updating Route53 information for {}'.format(self.account))

        # region Update zones
        existing_zones = DNSZone.get_all(self.account)
        zones = self.__fetch_route53_zones()
        for resource_id, data in zones.items():
            if resource_id in existing_zones:
                zone = DNSZone.get(resource_id)
                if zone.update(data):
                    self.log.debug('Change detected for Route53 zone {}/{}'.format(
                        self.account,
                        zone.name
                    ))
                    zone.save()
            else:
                tags = data.pop('tags')
                DNSZone.create(
                    resource_id,
                    account_id=self.account.account_id,
                    properties=data,
                    tags=tags
                )

                self.log.debug('Added Route53 zone {}/{}'.format(
                    self.account,
                    data['name']
                ))

        db.session.commit()

        zk = set(zones.keys())
        ezk = set(existing_zones.keys())

        for resource_id in ezk - zk:
            zone = existing_zones[resource_id]

            db.session.delete(zone.resource)
            self.log.debug('Deleted Route53 zone {}/{}'.format(
                self.account.account_name,
                zone.name.value
            ))
        db.session.commit()
        # endregion

        # region Update resource records
        try:
            for zone_id, zone in DNSZone.get_all(self.account).items():
                existing_records = {rec.id: rec for rec in zone.records}
                records = self.__fetch_route53_zone_records(zone.get_property('zone_id').value)

                for data in records:
                    if data['id'] in existing_records:
                        record = existing_records[data['id']]
                        if record.update(data):
                            self.log.debug('Changed detected for DNSRecord {}/{}/{}'.format(
                                self.account,
                                zone.name,
                                data['name']
                            ))
                            record.save()
                    else:
                        record = DNSRecord.create(
                            data['id'],
                            account_id=self.account.account_id,
                            properties={k: v for k, v in data.items() if k != 'id'},
                            tags={}
                        )
                        self.log.debug('Added new DNSRecord {}/{}/{}'.format(
                            self.account,
                            zone.name,
                            data['name']
                        ))
                        zone.add_record(record)
                db.session.commit()

                rk = set(x['id'] for x in records)
                erk = set(existing_records.keys())

                for resource_id in erk - rk:
                    record = existing_records[resource_id]
                    zone.delete_record(record)
                    self.log.debug('Deleted Route53 record {}/{}/{}'.format(
                        self.account.account_name,
                        zone_id,
                        record.name
                    ))
                db.session.commit()
        except:
            raise
        # endregion

    # region Helper functions
    @retry
    def __get_distribution_tags(self, client, arn):
        """Returns a dict containing the tags for a CloudFront distribution

        Args:
            client (botocore.client.CloudFront): Boto3 CloudFront client object
            arn (str): ARN of the distribution to get tags for

        Returns:
            `dict`
        """
        return {
            t['Key']: t['Value'] for t in client.list_tags_for_resource(
            Resource=arn
        )['Tags']['Items']
        }

    @retry
    def __fetch_route53_zones(self):
        """Return a list of all DNS zones hosted in Route53

        Returns:
            :obj:`list` of `dict`
        """
        done = False
        marker = None
        zones = {}
        route53 = self.session.client('route53')

        try:
            while not done:
                if marker:
                    response = route53.list_hosted_zones(Marker=marker)
                else:
                    response = route53.list_hosted_zones()

                if response['IsTruncated']:
                    marker = response['NextMarker']
                else:
                    done = True

                for zone_data in response['HostedZones']:
                    zones[get_resource_id('r53z', zone_data['Id'])] = {
                        'name': zone_data['Name'].rstrip('.'),
                        'source': 'AWS/{}'.format(self.account),
                        'comment': zone_data['Config']['Comment'] if 'Comment' in zone_data['Config'] else None,
                        'zone_id': zone_data['Id'],
                        'private_zone': zone_data['Config']['PrivateZone'],
                        'tags': self.__fetch_route53_zone_tags(zone_data['Id'])
                    }

            return zones
        finally:
            del route53

    @retry
    def __fetch_route53_zone_records(self, zone_id):
        """Return all resource records for a specific Route53 zone

        Args:
            zone_id (`str`): Name / ID of the hosted zone

        Returns:
            `dict`
        """
        route53 = self.session.client('route53')

        done = False
        nextName = nextType = None
        records = {}

        try:
            while not done:
                if nextName and nextType:
                    response = route53.list_resource_record_sets(
                        HostedZoneId=zone_id,
                        StartRecordName=nextName,
                        StartRecordType=nextType
                    )
                else:
                    response = route53.list_resource_record_sets(HostedZoneId=zone_id)

                if response['IsTruncated']:
                    nextName = response['NextRecordName']
                    nextType = response['NextRecordType']
                else:
                    done = True

                if 'ResourceRecordSets' in response:
                    for record in response['ResourceRecordSets']:
                        # Cannot make this a list, due to a race-condition in the AWS api that might return the same
                        # record more than once, so we use a dict instead to ensure that if we get duplicate records
                        # we simply just overwrite the one already there with the same info.
                        record_id = self._get_resource_hash(zone_id, record)
                        if 'AliasTarget' in record:
                            value = record['AliasTarget']['DNSName']
                            records[record_id] = {
                                'id': record_id,
                                'name': record['Name'].rstrip('.'),
                                'type': 'ALIAS',
                                'ttl': 0,
                                'value': [value]
                            }
                        else:
                            value = [y['Value'] for y in record['ResourceRecords']]
                            records[record_id] = {
                                'id': record_id,
                                'name': record['Name'].rstrip('.'),
                                'type': record['Type'],
                                'ttl': record['TTL'],
                                'value': value
                            }

            return list(records.values())
        finally:
            del route53

    @retry
    def __fetch_route53_zone_tags(self, zone_id):
        """Return a dict with the tags for the zone

        Args:
            zone_id (`str`): ID of the hosted zone

        Returns:
            :obj:`dict` of `str`: `str`
        """
        route53 = self.session.client('route53')

        try:
            return {
                tag['Key']: tag['Value'] for tag in
                route53.list_tags_for_resource(
                    ResourceType='hostedzone',
                    ResourceId=zone_id.split('/')[-1]
                )['ResourceTagSet']['Tags']
            }
        finally:
            del route53

    @staticmethod
    def _get_resource_hash(zone_name, record):
        """Returns the last ten digits of the sha256 hash of the combined arguments. Useful for generating unique
        resource IDs

        Args:
            zone_name (`str`): The name of the DNS Zone the record belongs to
            record (`dict`): A record dict to generate the hash from

        Returns:
            `str`
        """
        record_data = defaultdict(int, record)
        if type(record_data['GeoLocation']) == dict:
            record_data['GeoLocation'] = ":".join(["{}={}".format(k, v) for k, v in record_data['GeoLocation'].items()])

        args = [
            zone_name,
            record_data['Name'],
            record_data['Type'],
            record_data['Weight'],
            record_data['Region'],
            record_data['GeoLocation'],
            record_data['Failover'],
            record_data['HealthCheckId'],
            record_data['TrafficPolicyInstanceId']
        ]

        return get_resource_id('r53r', args)

    def _get_bucket_statistics(self, bucket_name, bucket_region, storage_type, statistic, days):
        """ Returns datapoints from cloudwatch for bucket statistics.

        Args:
            bucket_name `(str)`: The name of the bucket
            statistic `(str)`: The statistic you want to fetch from
            days `(int)`: Sample period for the statistic

        """

        cw = self.session.client('cloudwatch', region_name=bucket_region)

        # gather cw stats

        try:
            obj_stats = cw.get_metric_statistics(
                Namespace='AWS/S3',
                MetricName=statistic,
                Dimensions=[
                    {
                        'Name': 'StorageType',
                        'Value': storage_type
                    },
                    {
                        'Name': 'BucketName',
                        'Value': bucket_name
                    }
                ],
                Period=86400,
                StartTime=datetime.utcnow() - timedelta(days=days),
                EndTime=datetime.utcnow(),
                Statistics=[
                    'Average'
                ]
            )
            stat_value = obj_stats['Datapoints'][0]['Average'] if obj_stats['Datapoints'] else 'NO_DATA'

            return stat_value

        except Exception as e:
            self.log.error(
                'Could not get bucket statistic for account {} / bucket {} / {}'.format(self.account.account_name,
                                                                                        bucket_name, e))

        finally:
            del cw
Exemple #17
0
class StandaloneScheduler(BaseScheduler):
    """Main workers refreshing data from AWS
    """
    name = 'Standalone Scheduler'
    ns = 'scheduler_standalone'
    pool = None
    scheduler = None
    options = (
        ConfigOption('worker_threads', 20, 'int', 'Number of worker threads to spawn'),
        ConfigOption('worker_interval', 30, 'int', 'Delay between each worker thread being spawned, in seconds'),
    )

    def __init__(self):
        super().__init__()
        self.collectors = {}
        self.auditors = []
        self.region_workers = []

        self.pool = ProcessPoolExecutor(self.dbconfig.get('worker_threads', self.ns, 20))
        self.scheduler = APScheduler(
            threadpool=self.pool,
            job_defaults={
                'coalesce': True,
                'misfire_grace_time': 30
            }
        )

        self.load_plugins()

    def execute_scheduler(self):
        # Schedule a daily job to cleanup stuff thats been left around (eip's with no instances etc)
        self.scheduler.add_job(
            self.cleanup,
            trigger='cron',
            name='cleanup',
            hour=3,
            minute=0,
            second=0
        )

        # Schedule periodic scheduling of jobs
        self.scheduler.add_job(
            self.schedule_jobs,
            trigger='interval',
            name='schedule_jobs',
            seconds=60,
            start_date=datetime.now() + timedelta(seconds=1)
        )

        # Periodically reload the dbconfiguration
        self.scheduler.add_job(
            self.dbconfig.reload_data,
            trigger='interval',
            name='reload_dbconfig',
            minutes=5,
            start_date=datetime.now() + timedelta(seconds=3)
        )

        self.scheduler.start()

    def execute_worker(self):
        """This method is not used for the standalone scheduler."""
        print('The standalone scheduler does not have a separate worker model. '
              'Executing the scheduler will also execute the workers')

    def schedule_jobs(self):
        current_jobs = {
            x.name: x for x in self.scheduler.get_jobs() if x.name not in (
                'cleanup',
                'schedule_jobs',
                'reload_dbconfig'
            )
        }
        new_jobs = []
        start = datetime.now() + timedelta(seconds=1)
        _, accounts = BaseAccount.search(include_disabled=False)

        # region Global collectors (non-aws)
        if CollectorType.GLOBAL in self.collectors:
            for wkr in self.collectors[CollectorType.GLOBAL]:
                job_name = 'global_{}'.format(wkr.name)
                new_jobs.append(job_name)

                if job_name in current_jobs:
                    continue

                self.scheduler.add_job(
                    self.execute_global_worker,
                    trigger='interval',
                    name=job_name,
                    minutes=wkr.interval,
                    start_date=start,
                    args=[wkr],
                    kwargs={}
                )

                start += timedelta(seconds=30)
        # endregion

        # region AWS collectors
        aws_accounts = list(filter(lambda x: x.account_type == AWSAccount.account_type, accounts))
        for acct in aws_accounts:
            if CollectorType.AWS_ACCOUNT in self.collectors:
                for wkr in self.collectors[CollectorType.AWS_ACCOUNT]:
                    job_name = '{}_{}'.format(acct.account_name, wkr.name)
                    new_jobs.append(job_name)

                    if job_name in current_jobs:
                        continue

                    self.scheduler.add_job(
                        self.execute_aws_account_worker,
                        trigger='interval',
                        name=job_name,
                        minutes=wkr.interval,
                        start_date=start,
                        args=[wkr],
                        kwargs={'account': acct.account_name}
                    )

            if CollectorType.AWS_REGION in self.collectors:
                for wkr in self.collectors[CollectorType.AWS_REGION]:
                    for region in AWS_REGIONS:
                        job_name = '{}_{}_{}'.format(acct.account_name, region, wkr.name)
                        new_jobs.append(job_name)

                        if job_name in current_jobs:
                            continue

                        self.scheduler.add_job(
                            self.execute_aws_region_worker,
                            trigger='interval',
                            name=job_name,
                            minutes=wkr.interval,
                            start_date=start,
                            args=[wkr],
                            kwargs={'account': acct.account_name, 'region': region}
                        )
            db.session.commit()
            start += timedelta(seconds=self.dbconfig.get('worker_interval', self.ns, 30))
        # endregion

        # region Auditors
        start = datetime.now() + timedelta(seconds=1)
        for wkr in self.auditors:
            job_name = 'auditor_{}'.format(wkr.name)
            new_jobs.append(job_name)

            if job_name in current_jobs:
                continue

            if app_config.log_level == 'DEBUG':
                audit_start = start + timedelta(seconds=5)
            else:
                audit_start = start + timedelta(minutes=5)

            self.scheduler.add_job(
                self.execute_auditor_worker,
                trigger='interval',
                name=job_name,
                minutes=wkr.interval,
                start_date=audit_start,
                args=[wkr],
                kwargs={}
            )
            start += timedelta(seconds=self.dbconfig.get('worker_interval', self.ns, 30))
        # endregion

        extra_jobs = list(set(current_jobs) - set(new_jobs))
        for job in extra_jobs:
            self.log.warning('Removing job {} as it is no longer needed'.format(job))
            current_jobs[job].remove()

    def execute_global_worker(self, data, **kwargs):
        try:
            cls = self.get_class_from_ep(data.entry_point)
            worker = cls(**kwargs)
            self.log.info('RUN_INFO: {} starting at {}, next run will be at approximately {}'.format(data.entry_point['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), (datetime.now() + timedelta(minutes=data.interval)).strftime("%Y-%m-%d %H:%M:%S")))
            self.log.info('Starting global {} worker'.format(data.name))
            worker.run()

        except Exception as ex:
            self.log.exception('Global Worker {}: {}'.format(data.name, ex))

        finally:
            db.session.rollback()
            self.log.info('Completed run for global {} worker'.format(data.name))

    def execute_aws_account_worker(self, data, **kwargs):
        try:
            cls = self.get_class_from_ep(data.entry_point)
            worker = cls(**kwargs)
            self.log.info('RUN_INFO: {} starting at {}, next run will be at approximately {}'.format(data.entry_point['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), (datetime.now() + timedelta(minutes=data.interval)).strftime("%Y-%m-%d %H:%M:%S")))
            worker.run()

        except Exception as ex:
            self.log.exception('AWS Account Worker {}/{}: {}'.format(data.name, kwargs['account'], ex))

        finally:
            db.session.rollback()
            self.log.info('Completed run for {} worker on {}'.format(data.name, kwargs['account']))

    def execute_aws_region_worker(self, data, **kwargs):
        try:
            cls = self.get_class_from_ep(data.entry_point)
            worker = cls(**kwargs)
            self.log.info('RUN_INFO: {} starting at {} for account {} / region {}, next run will be at approximately {}'.format(data.entry_point['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), kwargs['account'], kwargs['region'], (datetime.now() + timedelta(minutes=data.interval)).strftime("%Y-%m-%d %H:%M:%S")))
            worker.run()

        except Exception as ex:
            self.log.exception('AWS Region Worker {}/{}/{}: {}'.format(
                data.name,
                kwargs['account'],
                kwargs['region'],
                ex
            ))

        finally:
            db.session.rollback()
            self.log.info('Completed run for {} worker on {}/{}'.format(
                data.name,
                kwargs['account'],
                kwargs['region']
            ))

    def execute_auditor_worker(self, data, **kwargs):
        try:
            cls = self.get_class_from_ep(data.entry_point)
            worker = cls(**kwargs)
            self.log.info('RUN_INFO: {} starting at {}, next run will be at approximately {}'.format(data.entry_point['module_name'], datetime.now().strftime("%Y-%m-%d %H:%M:%S"), (datetime.now() + timedelta(minutes=data.interval)).strftime("%Y-%m-%d %H:%M:%S")))
            worker.run()

        except Exception as ex:
            self.log.exception('Auditor Worker {}: {}'.format(data.name, ex))

        finally:
            db.session.rollback()
            self.log.info('Completed run for auditor {}'.format(data.name))

    def cleanup(self):
        try:
            self.log.info('Running cleanup tasks')

            log_purge_date = datetime.now() - timedelta(days=self.dbconfig.get('log_keep_days', 'log', default=31))
            db.LogEvent.find(LogEvent.timestamp < log_purge_date)

            db.session.commit()
        finally:
            db.session.rollback()