def __init__(self, *rule_paths):
    RulesEngine._config = RulesEngine._config or load_config()
    RulesEngine._threat_intel = (
        RulesEngine._threat_intel or ThreatIntel.load_from_config(self.config)
    )
    # Instantiate the alert forwarder to handle sending alerts to the alert processor
    RulesEngine._alert_forwarder = RulesEngine._alert_forwarder or AlertForwarder()

    # Load the lookup tables
    RulesEngine._lookup_tables = LookupTables.get_instance(config=self.config)

    # If no rule import paths are specified, default to the config
    rule_paths = rule_paths or [
        item for location in {'rule_locations', 'matcher_locations'}
        for item in self.config['global']['general'][location]
    ]

    import_folders(*rule_paths)

    self._rule_stat_tracker = RuleStatisticTracker(
        'STREAMALERT_TRACK_RULE_STATS' in env,
        'LAMBDA_RUNTIME_DIR' in env
    )
    self._required_outputs_set = resources.get_required_outputs()
    self._load_rule_table(self.config)
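# A minimal, self-contained sketch (illustrative names, not the real streamalert
# imports) of the class-attribute caching pattern used in RulesEngine.__init__
# above: expensive resources are stored on the class, so warm Lambda containers
# reuse them across invocations instead of rebuilding them per instance.
class CachedEngine:
    _config = None  # shared across all instances and all warm invocations

    def __init__(self):
        # "or" short-circuits, so the expensive load only runs on a cold start
        CachedEngine._config = CachedEngine._config or self._expensive_load()

    @staticmethod
    def _expensive_load():
        print('loading config (cold start only)')
        return {'loaded': True}

CachedEngine()  # prints once
CachedEngine()  # cache hit; prints nothing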
def test_load_exclude():
    """Shared - Config Loading - Exclude"""
    config = load_config(exclude={'global.json', 'logs.json'})
    expected_keys = {
        'clusters',
        'lambda',
        'outputs',
        'threat_intel',
        'normalized_types'
    }
    assert_equal(set(config), expected_keys)
def __init__(self):
    config = load_config(include={'lambda.json', 'global.json'})
    prefix = config['global']['account']['prefix']
    athena_config = config['lambda']['athena_partitioner_config']

    self._file_format = get_data_file_format(config)

    if self._file_format == 'parquet':
        self._alerts_regex = self.ALERTS_REGEX_PARQUET
        self._data_regex = self.DATA_REGEX_PARQUET
    elif self._file_format == 'json':
        self._alerts_regex = self.ALERTS_REGEX
        self._data_regex = self.DATA_REGEX
    else:
        message = (
            'File format "{}" is not supported. Supported file formats are '
            '"parquet" and "json". Please update the setting in '
            'athena_partitioner_config in "conf/lambda.json"'.format(self._file_format)
        )
        raise ConfigError(message)

    self._athena_buckets = athena_partition_buckets(config)

    db_name = get_database_name(config)

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}-streamalert-athena-results'.format(prefix)
    )

    self._s3_buckets_and_keys = defaultdict(set)

    self._create_client(db_name, results_bucket)
def __init__(self, config=None):
    self._configuration = {}

    if config is None:
        config = load_config()

    self._load_canonical_configurations(config)
def setUp(self):
    cli_config = CLIConfig(config_path='tests/unit/conf')
    with patch('streamalert.rules_engine.rules_engine.load_config',
               Mock(return_value=load_config(self.TEST_CONFIG_PATH))):
        self.runner = TestRunner(MagicMock(), cli_config)

    self.setUpPyfakefs()
def test_load_all():
    """Shared - Config Loading - All"""
    config = load_config()
    expected_keys = {
        'clusters',
        'global',
        'lambda',
        'logs',
        'outputs',
        'threat_intel',
        'normalized_types'
    }
    assert_equal(set(config), expected_keys)
def __init__(self, config_path, extra_terraform_files=None, build_directory=None):
    self.config_path = config_path
    self.config = config.load_config(config_path)
    self._terraform_files = extra_terraform_files or []
    self.build_directory = self._setup_build_directory(build_directory)
def __init__(self): """Initialization logic that can be cached across invocations""" # Merge user-specified output configuration with the required output configuration output_config = load_config(include={'outputs.json'})['outputs'] self.config = resources.merge_required_outputs( output_config, env['STREAMALERT_PREFIX']) self.alerts_table = AlertTable(env['ALERTS_TABLE'])
def test_load_include():
    """Shared - Config Loading - Include"""
    config = load_config(include={'clusters', 'logs.json'})
    expected_keys = ['clusters', 'logs']
    expected_clusters_keys = ['prod', 'dev']
    assert_count_equal(list(config.keys()), expected_keys)
    assert_count_equal(list(config['clusters'].keys()), expected_clusters_keys)
def setup(self): """Setup the AthenaClient tests""" self._db_name = 'test_database' config = load_config('tests/unit/conf/') prefix = config['global']['account']['prefix'] self.client = AthenaClient( self._db_name, 's3://{}-streamalert-athena-results'.format(prefix), 'unit-test')
def test_load_exclude_schemas():
    """Shared - Config Loading - Exclude Schemas"""
    config = load_config(conf_dir='conf_schemas', exclude={'schemas'})
    expected_keys = {
        'clusters',
        'global',
        'lambda',
        'outputs',
    }
    assert_equal(set(config), expected_keys)
def setup(self): """StatsPublisher - Setup""" # pylint: disable=attribute-defined-outside-init self.publisher = StatsPublisher( config=config.load_config('tests/unit/conf/'), athena_client=None, current_time=datetime(year=2000, month=1, day=1, hour=1, minute=1, second=1))
def __init__(self, artifacts_fh_stream_name):
    self._dst_firehose_stream_name = artifacts_fh_stream_name
    self._artifacts = list()

    ArtifactExtractor._config = ArtifactExtractor._config or config.load_config(validate=True)

    ArtifactExtractor._firehose_client = (
        ArtifactExtractor._firehose_client or FirehoseClient.get_client(
            prefix=self.config['global']['account']['prefix'],
            artifact_extractor_config=self.config['global'].get(
                'infrastructure', {}).get('artifact_extractor', {})
        )
    )
def setup(self): """RulePromoter - Setup""" # pylint: disable=attribute-defined-outside-init self.dynamo_mock = mock_dynamodb2() self.dynamo_mock.start() with patch('streamalert.rule_promotion.promoter.load_config') as config_mock, \ patch('streamalert.rule_promotion.promoter.StatsPublisher', Mock()), \ patch('boto3.client', _mock_boto), \ patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'}): setup_mock_rules_table(_RULES_TABLE) config_mock.return_value = config.load_config('tests/unit/conf/') self.promoter = RulePromoter() self._add_fake_stats()
def _load_config(function_arn):
    """Load the Threat Intel Downloader configuration from the conf/lambda.json file

    Returns:
        dict: Configuration for the Threat Intel Downloader

    Raises:
        ConfigError: For invalid or missing configuration files
    """
    base_config = parse_lambda_arn(function_arn)
    config = load_config(include={'lambda.json'})['lambda']
    base_config.update(config.get('threat_intel_downloader_config', {}))
    return base_config
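# Hedged sketch of the ARN parsing that parse_lambda_arn performs to build
# base_config above; this is an assumption inferred from the expected_config in
# TestThreatStream.test_load_config below, not the actual streamalert helper.
def parse_lambda_arn_sketch(function_arn):
    # Format: arn:aws:lambda:<region>:<account-id>:function:<name>[:<qualifier>]
    split_arn = function_arn.split(':')
    return {
        'region': split_arn[3],
        'account_id': split_arn[4],
        'function_name': split_arn[6],
        'qualifier': split_arn[7] if len(split_arn) > 7 else None,
    }

assert parse_lambda_arn_sketch(
    'arn:aws:lambda:region:123456789012:function:name:development'
) == {
    'region': 'region',
    'account_id': '123456789012',
    'function_name': 'name',
    'qualifier': 'development',
}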
def setup(self): """LookupTables - Setup S3 bucket mocking""" self.config = load_config('tests/unit/conf') self.s3_mock = mock_s3() self.s3_mock.start() self.dynamodb_mock = mock_dynamodb2() self.dynamodb_mock.start() self._put_mock_data() self._lookup_tables = LookupTables.get_instance(config=self.config, reset=True)
def setup(self): """LookupTables - Setup S3 bucket mocking""" self.config = load_config('tests/unit/conf') self._dynamodb_mock = mock_dynamodb2() self._dynamodb_mock.start() self._int_driver = construct_persistence_driver( self.config['lookup_tables']['tables']['dinosaur_multi_int']) self._string_driver = construct_persistence_driver( self.config['lookup_tables']['tables']['dinosaur_multi_string']) self._dict_driver = construct_persistence_driver( self.config['lookup_tables']['tables']['dinosaur_multi_dict']) self._put_mock_tables()
def test_process_test_file(self):
    """StreamAlert CLI - TestRunner Process Test File"""
    self.fs.create_file(
        self._DEFAULT_EVENT_PATH,
        contents=basic_test_file_json(
            log='unit_test_simple_log',
            source='unit_test_default_stream',  # valid source
            service='kinesis'  # valid service
        )
    )
    self.fs.add_real_directory(self.TEST_CONFIG_PATH)
    with patch('streamalert.classifier.classifier.config.load_config',
               Mock(return_value=load_config(self.TEST_CONFIG_PATH))):
        self.runner._process_test_file(self._DEFAULT_EVENT_PATH)

    # The CLUSTER env var should be properly deduced and set now
    assert_equal(os.environ['CLUSTER'], 'test')
def setup(self): """LookupTables - Setup S3 bucket mocking""" self.config = load_config('tests/unit/conf') self._dynamodb_mock = mock_dynamodb2() self._dynamodb_mock.start() self._driver = construct_persistence_driver( self.config['lookup_tables']['tables']['dinosaur']) self._bad_driver = construct_persistence_driver({ 'driver': 'dynamodb', 'table': 'table???', 'partition_key': '??', 'value_key': '?zlaerf', }) self._put_mock_tables()
def setup(self): """LookupTables - Setup S3 bucket mocking""" self.buckets_info = {'bucket_name': ['foo.json', 'bar.json']} self.config = load_config('tests/unit/conf') self.s3_mock = mock_s3() self.s3_mock.start() self._foo_driver = construct_persistence_driver( self.config['lookup_tables']['tables']['foo']) self._bar_driver = construct_persistence_driver( self.config['lookup_tables']['tables']['bar']) self._bad_driver = construct_persistence_driver({ 'driver': 's3', 'bucket': 'bucket_name', 'key': 'invalid-key', }) self._put_mock_tables()
def __init__(self):
    config = load_config(include={'lambda.json', 'global.json'})
    prefix = config['global']['account']['prefix']
    athena_config = config['lambda']['athena_partition_refresh_config']

    self._athena_buckets = self.buckets_from_config(config)

    db_name = athena_config.get('database_name', self.STREAMALERT_DATABASE.format(prefix))

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}-streamalert-athena-results'.format(prefix)
    )

    self._s3_buckets_and_keys = defaultdict(set)

    self._create_client(db_name, results_bucket)
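# Small runnable illustration of the results-bucket fallback used in the
# initializers above: an explicit 'results_bucket' in the Athena config wins,
# otherwise a default is derived from the account prefix ('acme' is hypothetical).
prefix = 'acme'
athena_config = {}  # no explicit results_bucket configured
results_bucket = athena_config.get(
    'results_bucket', 's3://{}-streamalert-athena-results'.format(prefix))
assert results_bucket == 's3://acme-streamalert-athena-results'

athena_config = {'results_bucket': 's3://my-custom-results'}
assert athena_config.get(
    'results_bucket',
    's3://{}-streamalert-athena-results'.format(prefix)) == 's3://my-custom-results'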
def __init__(self):
    # Create some objects to be cached if they have not already been created
    Classifier._config = Classifier._config or config.load_config(validate=True)
    Classifier._firehose_client = (
        Classifier._firehose_client or FirehoseClient.load_from_config(
            prefix=self.config['global']['account']['prefix'],
            firehose_config=self.config['global'].get('infrastructure', {}).get('firehose', {}),
            log_sources=self.config['logs']
        )
    )
    Classifier._sqs_client = Classifier._sqs_client or SQSClient()

    # Set up the normalization logic
    Normalizer.load_from_config(self.config)

    self._cluster = os.environ['CLUSTER']
    self._payloads = []
    self._failed_record_count = 0
    self._processed_size = 0
def __init__(self):
    self._config = load_config()
    prefix = self._config['global']['account']['prefix']

    # Create the rule table class for getting staging information
    self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

    athena_config = self._config['lambda']['athena_partitioner_config']

    # Get the name of the athena database to access
    db_name = athena_config.get('database_name', get_database_name(self._config))

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}-streamalert-athena-results'.format(prefix)
    )

    self._athena_client = AthenaClient(db_name, results_bucket, self.ATHENA_S3_PREFIX)
    self._current_time = datetime.utcnow()
    self._staging_stats = dict()
@classmethod
def get_instance(cls, config=None, reset=False):
    """Returns a singleton instance of LookupTablesCore

    Params:
        config (dict): OPTIONAL. Provide this to override default behavior or as an
            optimization. Be careful; once loaded, the LookupTables instance is cached
            statically and future invocations will ignore this config parameter, even
            when provided.
        reset (bool): OPTIONAL. Flag designating whether the cached instance of
            LookupTablesCore should be re-instantiated. Defaults to False.

    Returns:
        LookupTablesCore
    """
    if not cls._instance or reset:
        if config is None:
            config = load_config()
        cls._instance = LookupTablesCore(config)
        cls._instance.setup_tables()

    return cls._instance
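# Self-contained demonstration of the get_instance() contract documented above
# (the stub names are illustrative, not the real LookupTablesCore): the config
# argument is honored only on first construction or when reset=True.
class _CoreStub:
    def __init__(self, config):
        self.config = config

class _RegistryStub:
    _instance = None

    @classmethod
    def get_instance(cls, config=None, reset=False):
        if not cls._instance or reset:
            cls._instance = _CoreStub(config if config is not None else {})
        return cls._instance

first = _RegistryStub.get_instance(config={'v': 1})
second = _RegistryStub.get_instance(config={'v': 2})  # cached; config ignored
assert first is second and second.config == {'v': 1}
fresh = _RegistryStub.get_instance(config={'v': 2}, reset=True)
assert fresh is not first and fresh.config == {'v': 2}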
class TestThreatStream:
    """Test class to test ThreatStream functionalities"""
    # pylint: disable=protected-access

    @patch('streamalert.threat_intel_downloader.main.load_config',
           Mock(return_value=load_config('tests/unit/conf/')))
    def setup(self):
        """Setup TestThreatStream"""
        # pylint: disable=attribute-defined-outside-init
        context = get_mock_lambda_context('prefix_threat_intel_downloader', 100000)
        self.threatstream = ThreatStream(
            context.invoked_function_arn,
            context.get_remaining_time_in_millis
        )

    @staticmethod
    def _get_fake_intel(value, source):
        return {
            'value': value,
            'itype': 'c2_domain',
            'source': source,
            'type': 'domain',
            'expiration_ts': '2017-11-30T00:01:02.123Z',
            'key1': 'value1',
            'key2': 'value2'
        }

    @staticmethod
    def _get_http_response(next_url=None):
        return {
            'key1': 'value1',
            'objects': [
                TestThreatStream._get_fake_intel('malicious_domain.com', 'ioc_source'),
                TestThreatStream._get_fake_intel('malicious_domain2.com', 'test_source')
            ],
            'meta': {
                'next': next_url,
                'offset': 100
            }
        }

    @patch('streamalert.threat_intel_downloader.main.load_config',
           Mock(return_value=load_config('tests/unit/conf/')))
    def test_load_config(self):
        """ThreatStream - Load Config"""
        arn = 'arn:aws:lambda:region:123456789012:function:name:development'
        expected_config = {
            'account_id': '123456789012',
            'function_name': 'name',
            'qualifier': 'development',
            'region': 'region',
            'enabled': True,
            'excluded_sub_types': ['bot_ip', 'brute_ip', 'scan_ip', 'spam_ip', 'tor_ip'],
            'ioc_filters': ['crowdstrike', '@airbnb.com'],
            'ioc_keys': ['expiration_ts', 'itype', 'source', 'type', 'value'],
            'ioc_types': ['domain', 'ip', 'md5'],
            'memory': '128',
            'timeout': '60'
        }
        assert_equal(self.threatstream._load_config(arn), expected_config)

    def test_process_data(self):
        """ThreatStream - Process Raw IOC Data"""
        raw_data = [
            self._get_fake_intel('malicious_domain.com', 'ioc_source'),
            self._get_fake_intel('malicious_domain2.com', 'ioc_source2'),
            # this will get filtered out
            self._get_fake_intel('malicious_domain3.com', 'bad_source_ioc'),
        ]

        self.threatstream._config['ioc_filters'] = {'ioc_source'}

        processed_data = self.threatstream._process_data(raw_data)

        expected_result = [
            {
                'value': 'malicious_domain.com',
                'itype': 'c2_domain',
                'source': 'ioc_source',
                'type': 'domain',
                'expiration_ts': 1512000062
            },
            {
                'value': 'malicious_domain2.com',
                'itype': 'c2_domain',
                'source': 'ioc_source2',
                'type': 'domain',
                'expiration_ts': 1512000062
            }
        ]

        assert_equal(processed_data, expected_result)

    @mock_ssm
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds(self):
        """ThreatStream - Load API creds from SSM"""
        value = {'api_user': 'test_user', 'api_key': 'test_key'}
        put_mock_params(ThreatStream.CRED_PARAMETER_NAME, value)
        self.threatstream._load_api_creds()
        assert_equal(self.threatstream.api_user, 'test_user')
        assert_equal(self.threatstream.api_key, 'test_key')

    @mock_ssm
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds_cached(self):
        """ThreatStream - Load API creds from SSM, Cached"""
        value = {'api_user': 'test_user', 'api_key': 'test_key'}
        put_mock_params(ThreatStream.CRED_PARAMETER_NAME, value)
        self.threatstream._load_api_creds()
        assert_equal(self.threatstream.api_user, 'test_user')
        assert_equal(self.threatstream.api_key, 'test_key')
        self.threatstream._load_api_creds()  # second call should use the cached creds

    @mock_ssm
    @raises(ClientError)
    def test_load_api_creds_client_errors(self):
        """ThreatStream - Load API creds from SSM, ClientError"""
        self.threatstream._load_api_creds()

    @patch('boto3.client')
    @raises(ThreatStreamCredsError)
    def test_load_api_creds_empty_response(self, boto_mock):
        """ThreatStream - Load API creds from SSM, Empty Response"""
        boto_mock.return_value.get_parameter.return_value = None
        self.threatstream._load_api_creds()

    @mock_ssm
    @raises(ThreatStreamCredsError)
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds_invalid_json(self):
        """ThreatStream - Load API creds from SSM with invalid JSON"""
        boto3.client('ssm').put_parameter(
            Name=ThreatStream.CRED_PARAMETER_NAME,
            Value='invalid_value',
            Type='SecureString',
            Overwrite=True
        )
        self.threatstream._load_api_creds()

    @mock_ssm
    @raises(ThreatStreamCredsError)
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds_no_api_key(self):
        """ThreatStream - Load API creds from SSM, No API Key"""
        value = {'api_user': 'test_user', 'api_key': ''}
        put_mock_params(ThreatStream.CRED_PARAMETER_NAME, value)
        self.threatstream._load_api_creds()

    @patch('streamalert.threat_intel_downloader.main.datetime')
    def test_epoch_now(self, date_mock):
        """ThreatStream - Epoch, Now"""
        fake_date_now = datetime(year=2017, month=9, day=1)
        date_mock.utcnow.return_value = fake_date_now
        date_mock.utcfromtimestamp = datetime.utcfromtimestamp
        expected_value = datetime(year=2017, month=11, day=30)
        value = self.threatstream._epoch_time(None)
        assert_equal(datetime.utcfromtimestamp(value), expected_value)

    def test_epoch_from_time(self):
        """ThreatStream - Epoch, From Timestamp"""
        expected_value = datetime(year=2017, month=11, day=30)
        value = self.threatstream._epoch_time('2017-11-30T00:00:00.000Z')
        assert_equal(datetime.utcfromtimestamp(value), expected_value)

    @raises(ValueError)
    def test_epoch_from_bad_time(self):
        """ThreatStream - Epoch, Error"""
        self.threatstream._epoch_time('20171130T00:00:00.000Z')

    def test_excluded_sub_types(self):
        """ThreatStream - Excluded Sub Types Property"""
        expected_value = ['bot_ip', 'brute_ip', 'scan_ip', 'spam_ip', 'tor_ip']
        assert_equal(self.threatstream.excluded_sub_types, expected_value)

    def test_ioc_keys(self):
        """ThreatStream - IOC Keys Property"""
        expected_value = ['expiration_ts', 'itype', 'source', 'type', 'value']
        assert_equal(self.threatstream.ioc_keys, expected_value)

    def test_ioc_sources(self):
        """ThreatStream - IOC Sources Property"""
        expected_value = ['crowdstrike', '@airbnb.com']
        assert_equal(self.threatstream.ioc_sources, expected_value)

    def test_ioc_types(self):
        """ThreatStream - IOC Types Property"""
        expected_value = ['domain', 'ip', 'md5']
        assert_equal(self.threatstream.ioc_types, expected_value)

    def test_threshold(self):
        """ThreatStream - Threshold Property"""
        assert_equal(self.threatstream.threshold, 499000)

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._finalize')
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect(self, get_mock, finalize_mock):
        """ThreatStream - Connection to ThreatStream.com"""
        get_mock.return_value.json.return_value = self._get_http_response()
        get_mock.return_value.status_code = 200

        self.threatstream._config['ioc_filters'] = {'test_source'}
        self.threatstream._connect('previous_url')

        expected_intel = [{
            'value': 'malicious_domain2.com',
            'itype': 'c2_domain',
            'source': 'test_source',
            'type': 'domain',
            'expiration_ts': 1512000062
        }]
        finalize_mock.assert_called_with(expected_intel, None)

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._finalize')
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_next(self, get_mock, finalize_mock):
        """ThreatStream - Connection to ThreatStream.com, with Continuation"""
        next_url = 'this_url'
        get_mock.return_value.json.return_value = self._get_http_response(next_url)
        get_mock.return_value.status_code = 200

        self.threatstream._config['ioc_filters'] = {'test_source'}
        self.threatstream._connect('previous_url')

        expected_intel = [{
            'value': 'malicious_domain2.com',
            'itype': 'c2_domain',
            'source': 'test_source',
            'type': 'domain',
            'expiration_ts': 1512000062
        }]
        finalize_mock.assert_called_with(expected_intel, next_url)

    @raises(ThreatStreamRequestsError)
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_unauthed(self, get_mock):
        """ThreatStream - Connection to ThreatStream.com, Unauthorized Error"""
        get_mock.return_value.json.return_value = self._get_http_response()
        get_mock.return_value.status_code = 401
        self.threatstream._connect('previous_url')

    @raises(ThreatStreamRequestsError)
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_retry_error(self, get_mock):
        """ThreatStream - Connection to ThreatStream.com, Retry Error"""
        get_mock.return_value.status_code = 500
        self.threatstream._connect('previous_url')

    @raises(ThreatStreamRequestsError)
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_unknown_error(self, get_mock):
        """ThreatStream - Connection to ThreatStream.com, Unknown Error"""
        get_mock.return_value.status_code = 404
        self.threatstream._connect('previous_url')

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._load_api_creds')
    @patch('streamalert.threat_intel_downloader.main.ThreatStream._connect')
    def test_runner(self, connect_mock, _):
        """ThreatStream - Runner"""
        expected_url = (
            '/api/v2/intelligence/?username=user&api_key=key&limit=1000&q='
            '(status="active")+AND+(type="domain"+OR+type="ip"+OR+type="md5")+'
            'AND+NOT+(itype="bot_ip"+OR+itype="brute_ip"+OR+itype="scan_ip"+'
            'OR+itype="spam_ip"+OR+itype="tor_ip")'
        )
        self.threatstream.api_key = 'key'
        self.threatstream.api_user = 'user'
        self.threatstream.runner({'none': 'test'})
        connect_mock.assert_called_with(expected_url)

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._write_to_dynamodb_table')
    @patch('streamalert.threat_intel_downloader.main.ThreatStream._invoke_lambda_function')
    def test_finalize(self, invoke_mock, write_mock):
        """ThreatStream - Finalize with Intel"""
        intel = ['foo', 'bar']
        self.threatstream._finalize(intel, None)
        write_mock.assert_called_with(intel)
        invoke_mock.assert_not_called()

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._write_to_dynamodb_table')
    @patch('streamalert.threat_intel_downloader.main.ThreatStream._invoke_lambda_function')
    def test_finalize_next_url(self, invoke_mock, write_mock):
        """ThreatStream - Finalize with Next URL"""
        intel = ['foo', 'bar']
        self.threatstream._finalize(intel, 'next')
        write_mock.assert_called_with(intel)
        invoke_mock.assert_called_with('next')

    @patch('boto3.resource')
    def test_write_to_dynamodb_table(self, boto_mock):
        """ThreatStream - Write Intel to DynamoDB Table"""
        intel = [self._get_fake_intel('malicious_domain.com', 'test_source')]
        expected_intel = {
            'expiration_ts': '2017-11-30T00:01:02.123Z',
            'source': 'test_source',
            'ioc_type': 'domain',
            'sub_type': 'c2_domain',
            'ioc_value': 'malicious_domain.com'
        }
        self.threatstream._write_to_dynamodb_table(intel)
        batch_writer = boto_mock.return_value.Table.return_value.batch_writer.return_value
        batch_writer.__enter__.return_value.put_item.assert_called_with(Item=expected_intel)

    @patch('boto3.resource')
    @raises(ClientError)
    def test_write_to_dynamodb_table_error(self, boto_mock):
        """ThreatStream - Write Intel to DynamoDB Table, Error"""
        intel = [self._get_fake_intel('malicious_domain.com', 'test_source')]
        err = ClientError({'Error': {'Code': 404}}, 'PutItem')
        batch_writer = boto_mock.return_value.Table.return_value.batch_writer.return_value
        batch_writer.__enter__.return_value.put_item.side_effect = err
        self.threatstream._write_to_dynamodb_table(intel)

    @patch('boto3.client')
    def test_invoke_lambda_function(self, boto_mock):
        """ThreatStream - Invoke Lambda Function"""
        boto_mock.return_value = MockLambdaClient()
        self.threatstream._invoke_lambda_function('next_token')
        boto_mock.assert_called_once()

    @patch('boto3.client', Mock(return_value=MockLambdaClient()))
    @raises(ThreatStreamLambdaInvokeError)
    def test_invoke_lambda_function_error(self):
        """ThreatStream - Invoke Lambda Function, Error"""
        MockLambdaClient._raise_exception = True
        self.threatstream._invoke_lambda_function('next_token')
@classmethod
def import_publishers(cls):
    if not cls._is_imported:
        config = load_config()
        import_folders(*config['global']['general'].get('publisher_locations', []))
        cls._is_imported = True
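# Sketch of the one-shot import guard used by import_publishers above
# (illustrative names): the class-level flag makes repeated calls idempotent
# within a warm Lambda container, so the publisher folders are only walked once.
class _PublisherRegistryStub:
    _is_imported = False

    @classmethod
    def import_once(cls, loader):
        if not cls._is_imported:
            loader()  # the expensive folder import happens exactly once
            cls._is_imported = True

calls = []
_PublisherRegistryStub.import_once(lambda: calls.append('imported'))
_PublisherRegistryStub.import_once(lambda: calls.append('imported'))
assert calls == ['imported']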
class TestAthenaPartitioner:
    """Test class for AthenaPartitioner when output data is in Parquet format"""

    @patch('streamalert.athena_partitioner.main.load_config',
           Mock(return_value=load_config('tests/unit/conf/')))
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    @patch('streamalert.shared.athena.boto3')
    def setup(self, boto_patch):
        """Setup the AthenaPartitioner tests"""
        boto_patch.client.return_value = MockAthenaClient()
        self._partitioner = AthenaPartitioner()

    def test_add_partitions(self):
        """AthenaPartitioner - Add Partitions"""
        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {
                b'parquet/alerts/dt=2017-08-27-14/rule_name_alerts-1304134918401.parquet',
                b'parquet/alerts/dt=2020-02-13-08/prefix_streamalert_alert_delivery-01-abcd.parquet'
            },
            'unit-test-streamalert-data': {
                b'log_type_1/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/15/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/16/test-data-11111-22222-33333.snappy',
                b'log_type_3/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_1/2017/08/26/11/test-data-11111-22222-33333.snappy'
            },
            'test-bucket-with-data': {
                b'dt=2020-02-12-05/log_type_1_01234.parquet',
                b'dt=2020-02-12-06/log_type_1_abcd.parquet',
                b'dt=2020-02-12-06/log_type_2_0123.parquet',
                b'dt=2020-02-12-07/log_type_2_abcd.parquet'
            }
        }
        result = self._partitioner._add_partitions()
        assert_true(result)

    @patch('logging.Logger.warning')
    def test_add_partitions_none(self, log_mock):
        """AthenaPartitioner - Add Partitions, None to Add"""
        result = self._partitioner._add_partitions()
        log_mock.assert_called_with('No partitions to add')
        assert_equal(result, False)

    def test_get_partitions_from_keys_parquet(self):
        """AthenaPartitioner - Get Partitions From Keys in Parquet Format"""
        expected_result = {
            'alerts': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalerts/parquet/alerts/dt=2017-08-26-14'"),
                "(dt = '2017-08-27-14')": (
                    "'s3://unit-test-streamalerts/parquet/alerts/dt=2017-08-27-14'"),
                "(dt = '2017-08-26-15')": (
                    "'s3://unit-test-streamalerts/parquet/alerts/dt=2017-08-26-15'")
            },
            'log_type_1': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalert-data/parquet/log_type_1/dt=2017-08-26-14'")
            },
            'log_type_2': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalert-data/parquet/log_type_2/dt=2017-08-26-14'"),
                "(dt = '2017-08-26-15')": (
                    "'s3://unit-test-streamalert-data/parquet/log_type_2/dt=2017-08-26-15'"),
                "(dt = '2017-08-26-16')": (
                    "'s3://unit-test-streamalert-data/parquet/log_type_2/dt=2017-08-26-16'"),
            },
            'log_type_3': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalert-data/parquet/log_type_3/dt=2017-08-26-14'"),
            }
        }
        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {
                b'parquet/alerts/dt=2017-08-26-14/rule_name_alerts-1304134918401.parquet',
                b'parquet/alerts/dt=2017-08-27-14/rule_name_alerts-1304134918401.parquet',
                b'parquet/alerts/dt=2017-08-26-15/rule_name_alerts-1304134918401.parquet'
            },
            'unit-test-streamalert-data': {
                b'parquet/log_type_1/dt=2017-08-26-14/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_2/dt=2017-08-26-14/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_2/dt=2017-08-26-14/test-data-11111-22222-33334.snappy',
                b'parquet/log_type_2/dt=2017-08-26-15/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_2/dt=2017-08-26-16/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_3/dt=2017-08-26-14/test-data-11111-22222-33333.snappy',
            },
            'test-bucket-with-data': {
                b'dt=2017-08-26-14/rule_name_alerts-1304134918401.parquet',
                b'dt=2017-07-30-14/rule_name_alerts-1304134918401.parquet'
            }
        }
        result = self._partitioner._get_partitions_from_keys()
        assert_equal(result, expected_result)

    @patch('logging.Logger.warning')
    def test_get_partitions_from_keys_error(self, log_mock):
        """AthenaPartitioner - Get Partitions From Keys, Bad Key"""
        bad_key = b'bad_match_string'
        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {bad_key}
        }
        result = self._partitioner._get_partitions_from_keys()
        log_mock.assert_called_with(
            'The key %s does not match any regex, skipping',
            bad_key.decode('utf-8'))
        assert_equal(result, dict())

    @staticmethod
    def _s3_record(count):
        return {
            'Records': [
                {
                    's3': {
                        'bucket': {
                            'name': 'unit-test-streamalerts'
                        },
                        'object': {
                            'key': ('parquet/alerts/dt=2017-08-{:02d}-'
                                    '14/02/test.json'.format(val + 1))
                        }
                    }
                } for val in range(count)
            ]
        }

    @staticmethod
    def _s3_record_placeholder_file():
        return {
            'Records': [{
                's3': {
                    'bucket': {
                        'name': 'unit-test-streamalerts'
                    },
                    'object': {
                        'key': 'parquet/alerts/dt=2017-08-01-14/02/test.json_$folder$'
                    }
                }
            }]
        }

    @staticmethod
    def _create_test_message(count=2, placeholder=False):
        """Helper function for creating an SQS message body"""
        if placeholder:
            body = json.dumps(TestAthenaPartitioner._s3_record_placeholder_file())
        else:
            count = min(count, 30)
            body = json.dumps(TestAthenaPartitioner._s3_record(count))
        return {
            'Records': [{
                'body': body,
                'messageId': '40d4fac0-64a1-4a20-8be4-893c51aebca1',
                'attributes': {
                    'SentTimestamp': '1534284301036'
                }
            }]
        }

    @patch('logging.Logger.debug')
    @patch('streamalert.athena_partitioner.main.AthenaPartitioner._add_partitions')
    def test_run(self, add_mock, log_mock):
        """AthenaPartitioner - Run"""
        add_mock.return_value = True
        self._partitioner.run(self._create_test_message(1))
        log_mock.assert_called_with(
            'Received notification for object \'%s\' in bucket \'%s\'',
            'parquet/alerts/dt=2017-08-01-14/02/test.json'.encode(),
            'unit-test-streamalerts')

    @patch('logging.Logger.info')
    def test_run_placeholder_file(self, log_mock):
        """AthenaPartitioner - Run, Placeholder File"""
        self._partitioner.run(self._create_test_message(1, True))
        log_mock.assert_has_calls([
            call('Skipping placeholder file notification with key: %s',
                 b'parquet/alerts/dt=2017-08-01-14/02/test.json_$folder$')
        ])

    @patch('logging.Logger.warning')
    def test_run_no_messages(self, log_mock):
        """AthenaPartitioner - Run, No Messages"""
        self._partitioner.run(self._create_test_message(0))
        log_mock.assert_called_with('No partitions to add')

    @patch('logging.Logger.error')
    def test_run_invalid_bucket(self, log_mock):
        """AthenaPartitioner - Run, Bad Bucket Name"""
        event = self._create_test_message(0)
        bucket = 'bad.bucket.name'
        s3_record = self._s3_record(1)
        s3_record['Records'][0]['s3']['bucket']['name'] = bucket
        event['Records'][0]['body'] = json.dumps(s3_record)
        self._partitioner.run(event)
        log_mock.assert_called_with(
            '\'%s\' not found in \'buckets\' config. Please add this '
            'bucket to enable additions of Hive partitions.',
            bucket)
class TestAthenaPartitionerJSON:
    """Test class for AthenaPartitioner when output data is in JSON format"""

    @patch('streamalert.athena_partitioner.main.load_config',
           Mock(return_value=load_config('tests/unit/conf_athena/')))
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    @patch('streamalert.shared.athena.boto3')
    def setup(self, boto_patch):
        """Setup the AthenaPartitioner tests"""
        boto_patch.client.return_value = MockAthenaClient()
        self._partitioner = AthenaPartitioner()

    def test_get_partitions_from_keys_json(self):
        """AthenaPartitioner - Get Partitions From Keys in JSON Format"""
        expected_result = {
            'alerts': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalerts/parquet/alerts/dt=2017-08-26-14'"),
                "(dt = '2017-08-27-14')": (
                    "'s3://unit-test-streamalerts/parquet/alerts/dt=2017-08-27-14'"),
                "(dt = '2017-08-26-15')": (
                    "'s3://unit-test-streamalerts/alerts/2017/08/26/15'")
            },
            'log_type_1': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalert-data/log_type_1/2017/08/26/14'")
            },
            'log_type_2': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalert-data/log_type_2/2017/08/26/14'"),
                "(dt = '2017-08-26-15')": (
                    "'s3://unit-test-streamalert-data/log_type_2/2017/08/26/15'"),
                "(dt = '2017-08-26-16')": (
                    "'s3://unit-test-streamalert-data/log_type_2/2017/08/26/16'"),
            },
            'log_type_3': {
                "(dt = '2017-08-26-14')": (
                    "'s3://unit-test-streamalert-data/log_type_3/2017/08/26/14'"),
            }
        }
        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {
                b'parquet/alerts/dt=2017-08-26-14/rule_name_alerts-1304134918401.json',
                b'parquet/alerts/dt=2017-08-27-14/rule_name_alerts-1304134918401.json',
                b'alerts/2017/08/26/15/rule_name_alerts-1304134918401.json'
            },
            'unit-test-streamalert-data': {
                b'log_type_1/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/14/test-data-11111-22222-33334.snappy',
                b'log_type_2/2017/08/26/15/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/16/test-data-11111-22222-33333.snappy',
                b'log_type_3/2017/08/26/14/test-data-11111-22222-33333.snappy',
            },
            'test-bucket-with-data': {
                b'2017/08/26/14/rule_name_alerts-1304134918401.json',
                b'2017/07/30/14/rule_name_alerts-1304134918401.json'
            }
        }
        result = self._partitioner._get_partitions_from_keys()
        assert_equal(result, expected_result)
def test_load_schemas():
    """Shared - Config Loading - Schemas"""
    # Load from a separate dir where logs.json doesn't exist
    config = load_config(conf_dir='conf_schemas')
    basic_config = basic_streamalert_config()
    assert_equal(config['logs'], basic_config['logs'])
def test_load_schemas_logs():
    """Shared - Config Loading - Schemas and logs.json Both Exist"""
    # Check whether data was loaded from conf/logs.json or the schemas dir when both exist
    config = load_config(conf_dir='conf')

    # logs.json is preferred over the schemas directory for backwards compatibility
    assert_equal(config['logs'], {})
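# Illustration of the precedence the two tests above assert (a simplification,
# not the real loader): when conf/logs.json exists, its contents win over the
# schemas directory for backwards compatibility, even if logs.json is empty.
def resolve_log_sources(logs_json, schema_dir_logs):
    # logs.json takes precedence whenever the file is present
    return logs_json if logs_json is not None else schema_dir_logs

assert resolve_log_sources({}, {'some_log': {}}) == {}  # empty logs.json still wins
assert resolve_log_sources(None, {'some_log': {}}) == {'some_log': {}}  # fall back to schemas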