def test_short_circuit_without_exclude_list(self, is_excluded_ioc):
    """Threat Intel - ensure we skip threat intel exclusion if there are no excluded_iocs"""
    self.config = load_config('tests/unit/conf')
    self.config['global']['threat_intel']['enabled'] = True
    del self.config['global']['threat_intel']['excluded_iocs']
    self.threat_intel = StreamThreatIntel.load_from_config(self.config)
    records = [{
        'account': 12345,
        'region': '123456123456',
        'detail': {
            'eventType': 'AwsConsoleSignIn',
            'eventName': 'ConsoleLogin',
            'userIdentity': {
                'userName': '******',
                'type': 'Root',
                'principalId': '12345',
            },
            'sourceIPAddress': '8.8.8.8',
            'recipientAccountId': '12345'
        },
        'source': '8.8.8.8',
        'streamalert:normalization': {
            'sourceAddress': [['detail', 'sourceIPAddress'], ['source']],
            'userName': [['detail', 'userIdentity', 'userName']]
        },
        'id': '12345'
    }]

    records = mock_normalized_records(records)
    for record in records:
        result = self.threat_intel._extract_ioc_from_record(record)
        assert_equal(len(result), 1)
        assert_equal(result[0].value, '8.8.8.8')

    assert not is_excluded_ioc.called
def setup(self):
    """LookupTables - Setup S3 bucket mocking"""
    self.buckets_info = {'bucket_name': ['foo.json', 'bar.json']}
    self.config = load_config('tests/unit/conf')
    self.s3_mock = mock_s3()
    self.s3_mock.start()
    self._put_mock_tables()
def __init__(self, context):
    """Initializer

    Args:
        context (dict): An AWS context object which provides metadata on the currently
            executing lambda function.
    """
    # Load the config. Validation occurs during load, which will
    # raise exceptions on any ConfigError
    StreamAlert.config = StreamAlert.config or config.load_config(validate=True)

    # Load the environment from the context arn
    self.env = config.parse_lambda_arn(context.invoked_function_arn)

    # Instantiate the AlertForwarder here to handle sending the triggered alerts to the
    # alert processor
    self.alert_forwarder = AlertForwarder()

    # Instantiate a classifier that is used for this run
    self.classifier = StreamClassifier(config=self.config)

    self._failed_record_count = 0
    self._processed_record_count = 0
    self._processed_size = 0
    self._alerts = []

    rule_import_paths = [
        item for location in {'rule_locations', 'matcher_locations'}
        for item in self.config['global']['general'][location]
    ]

    # Create an instance of the RulesEngine class that gets cached in the
    # StreamAlert class as an instance property
    self._rules_engine = RulesEngine(self.config, *rule_import_paths)

    # Firehose client attribute
    self._firehose_client = None
def test_load_from_config_with_cluster_env(self):
    """Threat Intel - Test load_from_config to read cluster"""
    config = load_config('tests/unit/conf')
    config['global']['threat_intel']['enabled'] = True
    threat_intel = StreamThreatIntel.load_from_config(config)
    assert_is_instance(threat_intel, StreamThreatIntel)
    assert_true('advanced' in config['clusters'].keys())
def test_load_include():
    """Shared - Config Loading - Include"""
    config = load_config(include={'clusters', 'logs.json'})
    expected_keys = ['clusters', 'logs']
    expected_clusters_keys = ['prod', 'dev']
    assert_items_equal(config.keys(), expected_keys)
    assert_items_equal(config['clusters'].keys(), expected_clusters_keys)
def __init__(self):
    """Initialization logic that can be cached across invocations"""
    # Merge user-specified output configuration with the required output configuration
    output_config = load_config(include={'outputs.json'})['outputs']
    self.config = resources.merge_required_outputs(output_config, env['STREAMALERT_PREFIX'])

    self.alerts_table = AlertTable(env['ALERTS_TABLE'])
def test_load_exclude_clusters():
    """Shared - Config Loading - Exclude Clusters"""
    config = load_config(exclude={'clusters'})
    expected_keys = [
        'global',
        'lambda',
        'logs',
        'outputs',
        'sources',
        'types'
    ]
    assert_items_equal(config.keys(), expected_keys)
def __init__(self):
    config = load_config()
    prefix = config['global']['account']['prefix']

    # Create the rule table class for getting staging information
    self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

    athena_config = config['lambda']['athena_partition_refresh_config']

    # Get the name of the athena database to access
    db_name = athena_config.get('database_name', self.STREAMALERT_DATABASE.format(prefix))

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}.streamalert.athena-results'.format(prefix)
    )

    self._athena_client = AthenaClient(db_name, results_bucket, self.ATHENA_S3_PREFIX)
    self._current_time = datetime.utcnow()

    # Store the SNS topic arn to send alert stat information to
    self._publisher = StatsPublisher(config, self._athena_client, self._current_time)

    self._staging_stats = dict()
def setup(self):
    """Add a fake message to the queue."""
    self.mock_sqs = mock_sqs()
    self.mock_sqs.start()

    sqs = boto3.resource('sqs')
    config = load_config('tests/unit/conf/')
    prefix = config['global']['account']['prefix']
    name = StreamAlertSQSClient.DEFAULT_QUEUE_NAME.format(prefix)
    self.queue = sqs.create_queue(QueueName=name)

    # Create a fake s3 notification message to send
    bucket = 'unit-testing.streamalerts'
    test_s3_notification = {
        'Records': [
            {
                'eventVersion': '2.0',
                'eventSource': 'aws:s3',
                'awsRegion': 'us-east-1',
                'eventTime': '2017-08-07T18:26:30.956Z',
                'eventName': 'S3:PutObject',
                'userIdentity': {
                    'principalId': 'AWS:AAAAAAAAAAAAAAA'
                },
                'requestParameters': {
                    'sourceIPAddress': '127.0.0.1'
                },
                'responseElements': {
                    'x-amz-request-id': 'FOO',
                    'x-amz-id-2': 'BAR'
                },
                's3': {
                    's3SchemaVersion': '1.0',
                    'configurationId': 'queue',
                    'bucket': {
                        'name': bucket,
                        'ownerIdentity': {
                            'principalId': 'AAAAAAAAAAAAAAA'
                        },
                        'arn': 'arn:aws:s3:::{}'.format(bucket)
                    },
                    'object': {
                        'key': ('alerts/dt=2017-08-2{}-14-02/rule_name_alerts-'
                                '1304134918401.json'.format(day)),
                        'size': 1494,
                        'eTag': '12214134141431431',
                        'versionId': 'asdfasdfasdf.dfadCJkj1',
                        'sequencer': '1212312321312321321'
                    }
                }
            } for day in {6, 7}
        ]
    }

    self.queue.send_message(MessageBody=json.dumps(test_s3_notification))

    self.client = StreamAlertSQSClient(config)
def test_load_all():
    """Shared - Config Loading - All"""
    config = load_config()
    expected_keys = [
        'clusters',
        'global',
        'lambda',
        'logs',
        'outputs',
        'sources',
        'types'
    ]
    assert_items_equal(config.keys(), expected_keys)
def setup(self):
    """Setup before each method"""
    # Clear out the cached matchers and rules to avoid conflicts with production code
    Matcher._matchers.clear()
    Rule._rules.clear()
    self.config = load_config('tests/unit/conf')
    self.config['global']['threat_intel']['enabled'] = False
    self.rules_engine = RulesEngine(self.config)
def setup(self):
    """StatsPublisher - Setup"""
    # pylint: disable=attribute-defined-outside-init
    self.publisher = StatsPublisher(
        config=config.load_config('tests/unit/conf/'),
        athena_client=None,
        current_time=datetime(year=2000, month=1, day=1, hour=1, minute=1, second=1)
    )
def setup(self):
    """Setup the StreamAlertAthenaClient tests"""
    self._db_name = 'test_database'
    config = load_config('tests/unit/conf/')
    prefix = config['global']['account']['prefix']

    self.client = StreamAlertAthenaClient(
        self._db_name,
        's3://{}.streamalert.athena-results'.format(prefix),
        'unit-testing')
def test_load_enabled_sources_invalid_log(self, mock_logging):
    """StreamAlertFirehose - Load Enabled Sources - Invalid Log"""
    config = load_config('tests/unit/conf')
    firehose_config = {'enabled_logs': ['log-that-doesnt-exist']}

    sa_firehose = StreamAlertFirehose(
        region='us-east-1', firehose_config=firehose_config, log_sources=config['logs'])

    assert_equal(len(sa_firehose._enabled_logs), 0)
    assert_true(mock_logging.error.called)
def test_load_enabled_sources_invalid_log(self, mock_logging):
    """FirehoseClient - Load Enabled Sources - Invalid Log"""
    config = load_config('tests/unit/conf')
    firehose_config = {'enabled_logs': ['log-that-doesnt-exist']}

    sa_firehose = FirehoseClient(
        region='us-east-1', firehose_config=firehose_config, log_sources=config['logs'])

    assert_equal(len(sa_firehose._ENABLED_LOGS), 0)
    mock_logging.assert_called_with(
        'Enabled Firehose log %s not declared in logs.json',
        'log-that-doesnt-exist'
    )
def test_load_exclude():
    """Shared - Config Loading - Exclude"""
    config = load_config(exclude={'global.json', 'logs.json'})
    expected_keys = {
        'clusters',
        'lambda',
        'outputs',
        'sources',
        'threat_intel',
        'normalized_types'
    }
    assert_equal(set(config), expected_keys)
def test_load_enabled_sources(self):
    """FirehoseClient - Load Enabled Sources"""
    config = load_config('tests/unit/conf')
    firehose_config = {
        'enabled_logs': [
            'json:regex_key_with_envelope',
            'test_cloudtrail',
            'cloudwatch'  # expands to 2 logs
        ]
    }

    enabled_logs = FirehoseClient.load_enabled_log_sources(firehose_config, config['logs'])
    assert_equal(len(enabled_logs), 4)
    # Make sure the substitution works properly
    assert_true(all([':' not in log for log in enabled_logs]))
    assert_false(FirehoseClient.enabled_log_source('test_inspec'))
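# Illustrative only: a minimal sketch of the log-name substitution the
# assertion above checks for, assuming nested log names such as
# 'json:regex_key_with_envelope' have their ':' replaced with '_' when the
# enabled logs are expanded (Firehose delivery stream names cannot contain
# colons). The replace() call here is a stand-in, not the library's code.
log_name = 'json:regex_key_with_envelope'
sanitized = log_name.replace(':', '_')
assert ':' not in sanitized  # 'json_regex_key_with_envelope'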
def setup(self):
    """RulePromoter - Setup"""
    # pylint: disable=attribute-defined-outside-init
    self.dynamo_mock = mock_dynamodb2()
    self.dynamo_mock.start()
    with patch('stream_alert.rule_promotion.promoter.load_config') as config_mock, \
         patch('stream_alert.rule_promotion.promoter.StatsPublisher', Mock()), \
         patch('boto3.client', _mock_boto), \
         patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'}):
        setup_mock_rules_table(_RULES_TABLE)
        config_mock.return_value = config.load_config('tests/unit/conf/')
        self.promoter = RulePromoter()
        self._add_fake_stats()
def test_load_all():
    """Shared - Config Loading - All"""
    config = load_config()
    expected_keys = {
        'clusters',
        'global',
        'lambda',
        'logs',
        'outputs',
        'sources',
        'threat_intel',
        'normalized_types'
    }
    assert_equal(set(config), expected_keys)
def _load_config(function_arn):
    """Load the Threat Intel Downloader configuration from conf/lambda.json file

    Args:
        function_arn (str): The ARN of the currently executing Lambda function

    Returns:
        dict: Configuration for the Threat Intel Downloader

    Raises:
        ConfigError: For invalid or missing configuration files.
    """
    base_config = parse_lambda_arn(function_arn)
    config = load_config(include={'lambda.json'})['lambda']
    base_config.update(config.get('threat_intel_downloader_config', {}))
    return base_config
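# Illustrative only: a self-contained sketch of the merge order above. The
# ARN-derived metadata forms the base dict, and anything under
# 'threat_intel_downloader_config' in conf/lambda.json is layered on top via
# dict.update(), so user settings win on key collisions. The keys and values
# below are hypothetical.
base_config = {'region': 'us-east-1', 'account_id': '123456789012'}
base_config.update({'region': 'us-west-2', 'interval': 'rate(1 day)'})
assert base_config['region'] == 'us-west-2'  # lambda.json value overrides the base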
def setup(self):
    """LookupTables - Setup S3 bucket mocking"""
    # pylint: disable=attribute-defined-outside-init
    self.config = load_config('tests/unit/conf')
    self.lookup_tables = LookupTables(self.buckets_info)
    self.s3_mock = mock_s3()
    self.s3_mock.start()
    for bucket, files in self.buckets_info.iteritems():
        for json_file in files:
            put_mock_s3_object(
                bucket,
                json_file,
                json.dumps({
                    '{}_key'.format(bucket): '{}_value'.format(
                        os.path.splitext(json_file)[0])
                }),
                self.region
            )
def test_load_enabled_sources(self):
    """StreamAlertFirehose - Load Enabled Sources"""
    config = load_config('tests/unit/conf')
    firehose_config = {
        'enabled_logs': [
            'json:regex_key_with_envelope',
            'test_cloudtrail',
            'cloudwatch'  # expands to 2 logs
        ]
    }

    sa_firehose = StreamAlertFirehose(
        region='us-east-1', firehose_config=firehose_config, log_sources=config['logs'])

    assert_equal(len(sa_firehose._enabled_logs), 4)
    # Make sure the substitution works properly
    assert_true(all([':' not in log for log in sa_firehose.enabled_logs]))
    assert_false(sa_firehose.enabled_log_source('test_inspec'))
def __init__(self):
    # Create some objects to be cached if they have not already been created
    Classifier._config = Classifier._config or config.load_config(validate=True)
    Classifier._firehose_client = (
        Classifier._firehose_client or FirehoseClient.load_from_config(
            firehose_config=self.config['global'].get('infrastructure', {}).get('firehose', {}),
            log_sources=self.config['logs']
        )
    )
    Classifier._sqs_client = Classifier._sqs_client or SQSClient()

    # Setup the normalization logic
    Normalizer.load_from_config(self.config)

    self._payloads = []
    self._failed_record_count = 0
    self._processed_size = 0
def handler(event, context):
    """Lambda handler"""
    lambda_config = load_config(include={'lambda.json'})['lambda']
    config = lambda_config.get('threat_intel_downloader_config')
    config.update(parse_lambda_arn(context.invoked_function_arn))
    threat_stream = ThreatStream(config)
    intelligence, next_url, continue_invoke = threat_stream.runner(event)

    if intelligence:
        LOGGER.info('Write %d IOCs to DynamoDB table', len(intelligence))
        threat_stream.write_to_dynamodb_table(intelligence)

    if context.get_remaining_time_in_millis() > END_TIME_BUFFER * 1000 and continue_invoke:
        invoke_lambda_function(next_url, config)

    LOGGER.debug('Time remaining (MS): %s', context.get_remaining_time_in_millis())
def __init__(self, invoked_function_arn):
    """Initialization logic that can be cached across invocations.

    Args:
        invoked_function_arn (str): The ARN of the alert processor when it was invoked.
            This is used to calculate region, account, and prefix.
    """
    # arn:aws:lambda:REGION:ACCOUNT:function:PREFIX_streamalert_alert_processor:production
    split_arn = invoked_function_arn.split(':')
    self.region = split_arn[3]
    self.account_id = split_arn[4]
    self.prefix = split_arn[6].split('_')[0]

    # Merge user-specified output configuration with the required output configuration
    output_config = load_config(include={'outputs.json'})['outputs']
    self.config = resources.merge_required_outputs(output_config, self.prefix)

    self.alerts_table = AlertTable(os.environ['ALERTS_TABLE'])
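# Illustrative only: a worked example of the ARN indexing above, plugging a
# hypothetical account and prefix into the sample ARN from the comment.
arn = ('arn:aws:lambda:us-east-1:123456789012:function:'
       'acme_streamalert_alert_processor:production')
parts = arn.split(':')
assert parts[3] == 'us-east-1'           # region
assert parts[4] == '123456789012'        # account
assert parts[6].split('_')[0] == 'acme'  # prefix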
def __init__(self):
    config = load_config(include={'lambda.json', 'global.json'})
    prefix = config['global']['account']['prefix']
    athena_config = config['lambda']['athena_partition_refresh_config']

    self._athena_buckets = athena_config['buckets']

    db_name = athena_config.get('database_name', self.STREAMALERT_DATABASE.format(prefix))

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}.streamalert.athena-results'.format(prefix)
    )

    self._s3_buckets_and_keys = defaultdict(set)

    self._create_client(db_name, results_bucket)
def __init__(self):
    config = load_config(include={'lambda.json', 'global.json'})
    prefix = config['global']['account']['prefix']
    athena_config = config['lambda']['athena_partition_refresh_config']

    self._athena_buckets = athena_config['buckets']

    db_name = athena_config.get(
        'database_name',
        self.STREAMALERT_DATABASE.format(prefix)
    ).strip()

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}.streamalert.athena-results'.format(prefix)
    ).strip()

    self._athena_client = StreamAlertAthenaClient(
        db_name, results_bucket, self.ATHENA_S3_PREFIX)

    # Initialize the SQS client used to receive messages
    self._sqs_client = StreamAlertSQSClient(config)
def handler(*_):
    """Athena Partition Refresher Handler Function"""
    config = load_config(include={'lambda.json', 'global.json'})

    # Initialize the SQS client and receive messages
    stream_alert_sqs = StreamAlertSQSClient(config)

    # Get the first batch of messages from SQS. If there are no
    # messages, this will exit early.
    stream_alert_sqs.get_messages(max_tries=2)

    if not stream_alert_sqs.received_messages:
        LOGGER.info('No SQS messages received, exiting')
        return

    # If the max amount of messages was initially returned,
    # then get the next batch of messages. The max is determined based
    # on (number of tries) * (number of possible max messages returned)
    if len(stream_alert_sqs.received_messages) == 20:
        stream_alert_sqs.get_messages(max_tries=8)

    s3_buckets_and_keys = stream_alert_sqs.unique_s3_buckets_and_keys()
    if not s3_buckets_and_keys:
        LOGGER.error('No new Athena partitions to add, exiting')
        return

    # Initialize the Athena client and run queries
    stream_alert_athena = StreamAlertAthenaClient(config)

    # Check that the 'streamalert' database exists before running queries
    if not stream_alert_athena.check_database_exists():
        raise AthenaPartitionRefreshError(
            'The \'{}\' database does not exist'.format(stream_alert_athena.sa_database))

    if not stream_alert_athena.add_partition(s3_buckets_and_keys):
        LOGGER.error('Failed to add hive partition(s)')
        return

    stream_alert_sqs.delete_messages()
    LOGGER.info('Deleted %d messages from SQS', stream_alert_sqs.deleted_messages)
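# Illustrative only: the arithmetic behind the `== 20` check above. SQS caps
# each ReceiveMessage call at 10 messages, so polling with max_tries=2 can
# return at most 2 * 10 = 20; receiving exactly 20 suggests the queue may
# hold more messages, which is what triggers the second round of polling.
MAX_MESSAGES_PER_RECEIVE = 10  # hard SQS limit per ReceiveMessage call
first_batch_cap = 2 * MAX_MESSAGES_PER_RECEIVE
assert first_batch_cap == 20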
def __init__(self, *rule_paths):
    RulesEngine._config = RulesEngine._config or load_config()
    RulesEngine._threat_intel = (
        RulesEngine._threat_intel or ThreatIntel.load_from_config(self.config)
    )
    # Instantiate the alert forwarder to handle sending alerts to the alert processor
    RulesEngine._alert_forwarder = RulesEngine._alert_forwarder or AlertForwarder()

    # Load the lookup tables, which include logic for refreshing the tables
    RulesEngine._lookup_tables = LookupTables.load_lookup_tables(self.config)

    # If no rule import paths are specified, default to the config
    if not rule_paths:
        rule_paths = [
            item for location in {'rule_locations', 'matcher_locations'}
            for item in self.config['global']['general'][location]
        ]
    import_folders(*rule_paths)

    self._in_lambda = 'LAMBDA_RUNTIME_DIR' in env
    self._required_outputs_set = resources.get_required_outputs()
    self._load_rule_table(self.config)
def __init__(self, config_path=DEFAULT_CONFIG_PATH):
    self.config_path = config_path
    self.config = config.load_config(config_path)