Example #1
    def __init__(self, *rule_paths):
        RulesEngine._config = RulesEngine._config or load_config()
        RulesEngine._threat_intel = (
            RulesEngine._threat_intel or ThreatIntel.load_from_config(self.config)
        )
        # Instantiate the alert forwarder to handle sending alerts to the alert processor
        RulesEngine._alert_forwarder = RulesEngine._alert_forwarder or AlertForwarder()

        # Load the lookup tables
        RulesEngine._lookup_tables = LookupTables.get_instance(config=self.config)

        # If no rule import paths are specified, default to the config
        rule_paths = rule_paths or [
            item for location in {'rule_locations', 'matcher_locations'}
            for item in self.config['global']['general'][location]
        ]

        import_folders(*rule_paths)

        self._rule_stat_tracker = RuleStatisticTracker(
            'STREAMALERT_TRACK_RULE_STATS' in env,
            'LAMBDA_RUNTIME_DIR' in env
        )
        self._required_outputs_set = resources.get_required_outputs()
        self._load_rule_table(self.config)
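
Example #1 above, like examples #13 and #22 below, relies on a class-level caching idiom: expensive objects (the config, clients) are stored on the class itself so repeated invocations in a warm Lambda container reuse them instead of rebuilding them. A minimal, self-contained sketch of that idiom follows; the class and helper names here are illustrative, not taken from StreamAlert.

class Engine:
    _config = None  # cached on the class, shared across warm invocations

    def __init__(self):
        # Only the first instantiation pays the loading cost; later ones reuse the cache
        Engine._config = Engine._config or self._load_config()

    @staticmethod
    def _load_config():
        print('loading config...')  # stands in for load_config() in the real code
        return {'global': {}}

Engine()  # prints 'loading config...'
Engine()  # cache hit: nothing printed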
Example #2
 def test_load_exclude():
     """Shared - Config Loading - Exclude"""
     config = load_config(exclude={'global.json', 'logs.json'})
     expected_keys = {
         'clusters', 'lambda', 'outputs', 'threat_intel', 'normalized_types'
     }
     assert_equal(set(config), expected_keys)
Example #3
    def __init__(self):
        config = load_config(include={'lambda.json', 'global.json'})
        prefix = config['global']['account']['prefix']
        athena_config = config['lambda']['athena_partitioner_config']
        self._file_format = get_data_file_format(config)

        if self._file_format == 'parquet':
            self._alerts_regex = self.ALERTS_REGEX_PARQUET
            self._data_regex = self.DATA_REGEX_PARQUET

        elif self._file_format == 'json':
            self._alerts_regex = self.ALERTS_REGEX
            self._data_regex = self.DATA_REGEX
        else:
            message = (
                'file format "{}" is not supported. Supported file formats are '
                '"parquet" and "json". Please update the setting in athena_partitioner_config '
                'in "conf/lambda.json"'.format(self._file_format))
            raise ConfigError(message)

        self._athena_buckets = athena_partition_buckets(config)

        db_name = get_database_name(config)

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}-streamalert-athena-results'.format(prefix))

        self._s3_buckets_and_keys = defaultdict(set)

        self._create_client(db_name, results_bucket)
Example #4
    def __init__(self, config=None):
        self._configuration = {}

        if config is None:
            config = load_config()

        self._load_canonical_configurations(config)
Example #5
    def setUp(self):
        cli_config = CLIConfig(config_path='tests/unit/conf')
        with patch('streamalert.rules_engine.rules_engine.load_config',
                   Mock(return_value=load_config(self.TEST_CONFIG_PATH))):
            self.runner = TestRunner(MagicMock(), cli_config)

        self.setUpPyfakefs()
Example #6
 def test_load_all():
     """Shared - Config Loading - All"""
     config = load_config()
     expected_keys = {
         'clusters', 'global', 'lambda', 'logs', 'outputs', 'threat_intel',
         'normalized_types'
     }
     assert_equal(set(config), expected_keys)
Example #7
 def __init__(self,
              config_path,
              extra_terraform_files=None,
              build_directory=None):
     self.config_path = config_path
     self.config = config.load_config(config_path)
     self._terraform_files = extra_terraform_files or []
     self.build_directory = self._setup_build_directory(build_directory)
Example #8
    def __init__(self):
        """Initialization logic that can be cached across invocations"""
        # Merge user-specified output configuration with the required output configuration
        output_config = load_config(include={'outputs.json'})['outputs']
        self.config = resources.merge_required_outputs(
            output_config, env['STREAMALERT_PREFIX'])

        self.alerts_table = AlertTable(env['ALERTS_TABLE'])
Example #9
 def test_load_include():
     """Shared - Config Loading - Include"""
     config = load_config(include={'clusters', 'logs.json'})
     expected_keys = ['clusters', 'logs']
     expected_clusters_keys = ['prod', 'dev']
     assert_count_equal(list(config.keys()), expected_keys)
     assert_count_equal(list(config['clusters'].keys()),
                        expected_clusters_keys)
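
Examples #2, #6, and #9 together illustrate the include/exclude semantics of load_config: with no arguments every conf file is loaded, exclude drops the named files, and include restricts loading to just the named files or directories. A rough sketch of that filtering is shown below, assuming a flat directory of JSON files keyed by file stem; this is an illustration only, not the actual load_config implementation (which also handles sub-directories such as clusters/ and the schemas dir).

import json
import os

def load_config_sketch(conf_dir='conf', include=None, exclude=None):
    """Illustrative only: load each JSON file under conf_dir, honoring include/exclude."""
    config = {}
    for name in sorted(os.listdir(conf_dir)):
        if not name.endswith('.json'):
            continue  # the real loader also descends into directories like clusters/
        if exclude and name in exclude:
            continue  # e.g. exclude={'global.json', 'logs.json'}
        if include and name not in include:
            continue  # e.g. include={'clusters', 'logs.json'}
        with open(os.path.join(conf_dir, name)) as handle:
            config[os.path.splitext(name)[0]] = json.load(handle)
    return config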
Example #10
    def setup(self):
        """Setup the AthenaClient tests"""

        self._db_name = 'test_database'
        config = load_config('tests/unit/conf/')
        prefix = config['global']['account']['prefix']

        self.client = AthenaClient(
            self._db_name, 's3://{}-streamalert-athena-results'.format(prefix),
            'unit-test')
Example #11
 def test_load_exclude_schemas():
     """Shared - Config Loading - Exclude Schemas"""
     config = load_config(conf_dir='conf_schemas', exclude={'schemas'})
     expected_keys = {
         'clusters',
         'global',
         'lambda',
         'outputs',
     }
     assert_equal(set(config), expected_keys)
Example #12
 def setup(self):
     """StatsPublisher - Setup"""
     # pylint: disable=attribute-defined-outside-init
     self.publisher = StatsPublisher(
         config=config.load_config('tests/unit/conf/'),
         athena_client=None,
         current_time=datetime(year=2000,
                               month=1,
                               day=1,
                               hour=1,
                               minute=1,
                               second=1))
Example #13
    def __init__(self, artifacts_fh_stream_name):
        self._dst_firehose_stream_name = artifacts_fh_stream_name
        self._artifacts = list()

        ArtifactExtractor._config = ArtifactExtractor._config or config.load_config(
            validate=True)

        ArtifactExtractor._firehose_client = (
            ArtifactExtractor._firehose_client or FirehoseClient.get_client(
                prefix=self.config['global']['account']['prefix'],
                artifact_extractor_config=self.config['global'].get(
                    'infrastructure', {}).get('artifact_extractor', {})))
Example #14
 def setup(self):
     """RulePromoter - Setup"""
     # pylint: disable=attribute-defined-outside-init
     self.dynamo_mock = mock_dynamodb2()
     self.dynamo_mock.start()
     with patch('streamalert.rule_promotion.promoter.load_config') as config_mock, \
          patch('streamalert.rule_promotion.promoter.StatsPublisher', Mock()), \
          patch('boto3.client', _mock_boto), \
          patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'}):
         setup_mock_rules_table(_RULES_TABLE)
         config_mock.return_value = config.load_config('tests/unit/conf/')
         self.promoter = RulePromoter()
         self._add_fake_stats()
Example #15
    def _load_config(function_arn):
        """Load the Threat Intel Downloader configuration from the conf/lambda.json file

        Returns:
            (dict): Configuration for Threat Intel Downloader

        Raises:
            ConfigError: For invalid or missing configuration files.
        """

        base_config = parse_lambda_arn(function_arn)
        config = load_config(include={'lambda.json'})['lambda']
        base_config.update(config.get('threat_intel_downloader_config', {}))
        return base_config
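
_load_config above merges two sources: fields parsed out of the Lambda function ARN and the threat_intel_downloader_config block from conf/lambda.json; example #25's test_load_config shows the merged result. A hedged sketch of the ARN-parsing half is shown below, assuming the standard arn:aws:lambda:<region>:<account>:function:<name>[:<qualifier>] layout; StreamAlert's actual parse_lambda_arn may differ in detail.

def parse_lambda_arn_sketch(function_arn):
    """Illustrative parser for arn:aws:lambda:<region>:<account>:function:<name>[:<qualifier>]"""
    parts = function_arn.split(':')
    base = {
        'region': parts[3],
        'account_id': parts[4],
        'function_name': parts[6],
    }
    if len(parts) > 7:
        base['qualifier'] = parts[7]  # e.g. 'development' in example #25
    return base

parse_lambda_arn_sketch('arn:aws:lambda:region:123456789012:function:name:development')
# -> {'region': 'region', 'account_id': '123456789012',
#     'function_name': 'name', 'qualifier': 'development'}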
Example #16
    def setup(self):
        """LookupTables - Setup S3 bucket mocking"""
        self.config = load_config('tests/unit/conf')

        self.s3_mock = mock_s3()
        self.s3_mock.start()

        self.dynamodb_mock = mock_dynamodb2()
        self.dynamodb_mock.start()

        self._put_mock_data()

        self._lookup_tables = LookupTables.get_instance(config=self.config,
                                                        reset=True)
Example #17
    def setup(self):
        """LookupTables - Setup S3 bucket mocking"""
        self.config = load_config('tests/unit/conf')

        self._dynamodb_mock = mock_dynamodb2()
        self._dynamodb_mock.start()

        self._int_driver = construct_persistence_driver(
            self.config['lookup_tables']['tables']['dinosaur_multi_int'])
        self._string_driver = construct_persistence_driver(
            self.config['lookup_tables']['tables']['dinosaur_multi_string'])
        self._dict_driver = construct_persistence_driver(
            self.config['lookup_tables']['tables']['dinosaur_multi_dict'])

        self._put_mock_tables()
Example #18
    def test_process_test_file(self):
        """StreamAlert CLI - TestRunner Process Test File"""
        self.fs.create_file(
            self._DEFAULT_EVENT_PATH,
            contents=basic_test_file_json(
                log='unit_test_simple_log',
                source='unit_test_default_stream',  # valid source
                service='kinesis'  # valid service
            ))
        self.fs.add_real_directory(self.TEST_CONFIG_PATH)
        with patch('streamalert.classifier.classifier.config.load_config',
                   Mock(return_value=load_config(self.TEST_CONFIG_PATH))):
            self.runner._process_test_file(self._DEFAULT_EVENT_PATH)

        # The CLUSTER env var should be properly deduced and set now
        assert_equal(os.environ['CLUSTER'], 'test')
Example #19
    def setup(self):
        """LookupTables - Setup S3 bucket mocking"""
        self.config = load_config('tests/unit/conf')
        self._dynamodb_mock = mock_dynamodb2()
        self._dynamodb_mock.start()

        self._driver = construct_persistence_driver(
            self.config['lookup_tables']['tables']['dinosaur'])
        self._bad_driver = construct_persistence_driver({
            'driver': 'dynamodb',
            'table': 'table???',
            'partition_key': '??',
            'value_key': '?zlaerf',
        })

        self._put_mock_tables()
Example #20
    def setup(self):
        """LookupTables - Setup S3 bucket mocking"""
        self.buckets_info = {'bucket_name': ['foo.json', 'bar.json']}
        self.config = load_config('tests/unit/conf')
        self.s3_mock = mock_s3()
        self.s3_mock.start()

        self._foo_driver = construct_persistence_driver(
            self.config['lookup_tables']['tables']['foo'])
        self._bar_driver = construct_persistence_driver(
            self.config['lookup_tables']['tables']['bar'])
        self._bad_driver = construct_persistence_driver({
            'driver': 's3',
            'bucket': 'bucket_name',
            'key': 'invalid-key',
        })

        self._put_mock_tables()
Example #21
    def __init__(self):
        config = load_config(include={'lambda.json', 'global.json'})
        prefix = config['global']['account']['prefix']
        athena_config = config['lambda']['athena_partition_refresh_config']

        self._athena_buckets = self.buckets_from_config(config)

        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}-streamalert-athena-results'.format(prefix))

        self._s3_buckets_and_keys = defaultdict(set)

        self._create_client(db_name, results_bucket)
Example #22
    def __init__(self):
        # Create some objects to be cached if they have not already been created
        Classifier._config = Classifier._config or config.load_config(validate=True)
        Classifier._firehose_client = (
            Classifier._firehose_client or FirehoseClient.load_from_config(
                prefix=self.config['global']['account']['prefix'],
                firehose_config=self.config['global'].get('infrastructure', {}).get('firehose', {}),
                log_sources=self.config['logs']
            )
        )
        Classifier._sqs_client = Classifier._sqs_client or SQSClient()

        # Setup the normalization logic
        Normalizer.load_from_config(self.config)
        self._cluster = os.environ['CLUSTER']
        self._payloads = []
        self._failed_record_count = 0
        self._processed_size = 0
Example #23
    def __init__(self):
        self._config = load_config()
        prefix = self._config['global']['account']['prefix']

        # Create the rule table class for getting staging information
        self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

        athena_config = self._config['lambda']['athena_partitioner_config']

        # Get the name of the athena database to access
        db_name = athena_config.get('database_name',
                                    get_database_name(self._config))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}-streamalert-athena-results'.format(prefix))

        self._athena_client = AthenaClient(db_name, results_bucket,
                                           self.ATHENA_S3_PREFIX)
        self._current_time = datetime.utcnow()
        self._staging_stats = dict()
Example #24
    def get_instance(cls, config=None, reset=False):
        """
        Returns a singleton instance of LookupTablesCore.

        Params:
            config (dict) OPTIONAL: You can provide this to override default behavior or as an
                optimization. Be careful: once loaded, the instance is cached statically and
                future invocations will ignore this config parameter, even when provided.
            reset (bool) OPTIONAL: Flag designating whether or not the cached instance of
                LookupTablesCore should be re-instantiated. Default value is False.

        Returns:
            LookupTablesCore
        """
        if not cls._instance or reset:
            if config is None:
                config = load_config()

            cls._instance = LookupTablesCore(config)
            cls._instance.setup_tables()

        return cls._instance
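
The docstring above spells out the caching contract: the config argument only matters on the first call or when reset=True, because the constructed LookupTablesCore is stored on the class. A short usage sketch of that contract, assuming LookupTables and load_config are imported as in the surrounding examples:

# First call builds and caches the instance from the supplied config
tables = LookupTables.get_instance(config=load_config('tests/unit/conf'))

# Subsequent calls return the cached instance and ignore the config argument entirely
same_tables = LookupTables.get_instance(config={'this': 'is ignored'})
assert tables is same_tables

# reset=True discards the cache and rebuilds from the provided config (see example #16)
fresh_tables = LookupTables.get_instance(config=load_config('tests/unit/conf'), reset=True)
assert fresh_tables is not tables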
Example #25
class TestThreatStream:
    """Test class to test ThreatStream functionalities"""
    # pylint: disable=protected-access

    @patch('streamalert.threat_intel_downloader.main.load_config',
           Mock(return_value=load_config('tests/unit/conf/')))
    def setup(self):
        """Setup TestThreatStream"""
        # pylint: disable=attribute-defined-outside-init
        context = get_mock_lambda_context('prefix_threat_intel_downloader',
                                          100000)
        self.threatstream = ThreatStream(context.invoked_function_arn,
                                         context.get_remaining_time_in_millis)

    @staticmethod
    def _get_fake_intel(value, source):
        return {
            'value': value,
            'itype': 'c2_domain',
            'source': source,
            'type': 'domain',
            'expiration_ts': '2017-11-30T00:01:02.123Z',
            'key1': 'value1',
            'key2': 'value2'
        }

    @staticmethod
    def _get_http_response(next_url=None):
        return {
            'key1': 'value1',
            'objects': [
                TestThreatStream._get_fake_intel('malicious_domain.com',
                                                 'ioc_source'),
                TestThreatStream._get_fake_intel('malicious_domain2.com',
                                                 'test_source')
            ],
            'meta': {
                'next': next_url,
                'offset': 100
            }
        }

    @patch('streamalert.threat_intel_downloader.main.load_config',
           Mock(return_value=load_config('tests/unit/conf/')))
    def test_load_config(self):
        """ThreatStream - Load Config"""
        arn = 'arn:aws:lambda:region:123456789012:function:name:development'
        expected_config = {
            'account_id': '123456789012',
            'function_name': 'name',
            'qualifier': 'development',
            'region': 'region',
            'enabled': True,
            'excluded_sub_types': ['bot_ip', 'brute_ip', 'scan_ip', 'spam_ip', 'tor_ip'],
            'ioc_filters': ['crowdstrike', '@airbnb.com'],
            'ioc_keys': ['expiration_ts', 'itype', 'source', 'type', 'value'],
            'ioc_types': ['domain', 'ip', 'md5'],
            'memory': '128',
            'timeout': '60'
        }
        assert_equal(self.threatstream._load_config(arn), expected_config)

    def test_process_data(self):
        """ThreatStream - Process Raw IOC Data"""
        raw_data = [
            self._get_fake_intel('malicious_domain.com', 'ioc_source'),
            self._get_fake_intel('malicious_domain2.com', 'ioc_source2'),
            # this will get filtered out
            self._get_fake_intel('malicious_domain3.com', 'bad_source_ioc'),
        ]
        self.threatstream._config['ioc_filters'] = {'ioc_source'}
        processed_data = self.threatstream._process_data(raw_data)
        expected_result = [{
            'value': 'malicious_domain.com',
            'itype': 'c2_domain',
            'source': 'ioc_source',
            'type': 'domain',
            'expiration_ts': 1512000062
        }, {
            'value': 'malicious_domain2.com',
            'itype': 'c2_domain',
            'source': 'ioc_source2',
            'type': 'domain',
            'expiration_ts': 1512000062
        }]
        assert_equal(processed_data, expected_result)

    @mock_ssm
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds(self):
        """ThreatStream - Load API creds from SSM"""
        value = {'api_user': 'test_user', 'api_key': 'test_key'}
        put_mock_params(ThreatStream.CRED_PARAMETER_NAME, value)
        self.threatstream._load_api_creds()
        assert_equal(self.threatstream.api_user, 'test_user')
        assert_equal(self.threatstream.api_key, 'test_key')

    @mock_ssm
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds_cached(self):
        """ThreatStream - Load API creds from SSM, Cached"""
        value = {'api_user': 'test_user', 'api_key': 'test_key'}
        put_mock_params(ThreatStream.CRED_PARAMETER_NAME, value)
        self.threatstream._load_api_creds()
        assert_equal(self.threatstream.api_user, 'test_user')
        assert_equal(self.threatstream.api_key, 'test_key')
        self.threatstream._load_api_creds()

    @mock_ssm
    @raises(ClientError)
    def test_load_api_creds_client_errors(self):
        """ThreatStream - Load API creds from SSM, ClientError"""
        self.threatstream._load_api_creds()

    @patch('boto3.client')
    @raises(ThreatStreamCredsError)
    def test_load_api_creds_empty_response(self, boto_mock):
        """ThreatStream - Load API creds from SSM, Empty Response"""
        boto_mock.return_value.get_parameter.return_value = None
        self.threatstream._load_api_creds()

    @mock_ssm
    @raises(ThreatStreamCredsError)
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds_invalid_json(self):
        """ThreatStream - Load API creds from SSM with invalid JSON"""
        boto3.client('ssm').put_parameter(
            Name=ThreatStream.CRED_PARAMETER_NAME,
            Value='invalid_value',
            Type='SecureString',
            Overwrite=True)
        self.threatstream._load_api_creds()

    @mock_ssm
    @raises(ThreatStreamCredsError)
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    def test_load_api_creds_no_api_key(self):
        """ThreatStream - Load API creds from SSM, No API Key"""
        value = {'api_user': 'test_user', 'api_key': ''}
        put_mock_params(ThreatStream.CRED_PARAMETER_NAME, value)
        self.threatstream._load_api_creds()

    @patch('streamalert.threat_intel_downloader.main.datetime')
    def test_epoch_now(self, date_mock):
        """ThreatStream - Epoch, Now"""
        fake_date_now = datetime(year=2017, month=9, day=1)
        date_mock.utcnow.return_value = fake_date_now
        date_mock.utcfromtimestamp = datetime.utcfromtimestamp
        expected_value = datetime(year=2017, month=11, day=30)
        value = self.threatstream._epoch_time(None)
        assert_equal(datetime.utcfromtimestamp(value), expected_value)

    def test_epoch_from_time(self):
        """ThreatStream - Epoch, From Timestamp"""
        expected_value = datetime(year=2017, month=11, day=30)
        value = self.threatstream._epoch_time('2017-11-30T00:00:00.000Z')
        assert_equal(datetime.utcfromtimestamp(value), expected_value)

    @raises(ValueError)
    def test_epoch_from_bad_time(self):
        """ThreatStream - Epoch, Error"""
        self.threatstream._epoch_time('20171130T00:00:00.000Z')

    def test_excluded_sub_types(self):
        """ThreatStream - Excluded Sub Types Property"""
        expected_value = ['bot_ip', 'brute_ip', 'scan_ip', 'spam_ip', 'tor_ip']
        assert_equal(self.threatstream.excluded_sub_types, expected_value)

    def test_ioc_keys(self):
        """ThreatStream - IOC Keys Property"""
        expected_value = ['expiration_ts', 'itype', 'source', 'type', 'value']
        assert_equal(self.threatstream.ioc_keys, expected_value)

    def test_ioc_sources(self):
        """ThreatStream - IOC Sources Property"""
        expected_value = ['crowdstrike', '@airbnb.com']
        assert_equal(self.threatstream.ioc_sources, expected_value)

    def test_ioc_types(self):
        """ThreatStream - IOC Types Property"""
        expected_value = ['domain', 'ip', 'md5']
        assert_equal(self.threatstream.ioc_types, expected_value)

    def test_threshold(self):
        """ThreatStream - Threshold Property"""
        assert_equal(self.threatstream.threshold, 499000)

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._finalize')
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect(self, get_mock, finalize_mock):
        """ThreatStream - Connection to ThreatStream.com"""
        get_mock.return_value.json.return_value = self._get_http_response()
        get_mock.return_value.status_code = 200
        self.threatstream._config['ioc_filters'] = {'test_source'}
        self.threatstream._connect('previous_url')
        expected_intel = [{
            'value': 'malicious_domain2.com',
            'itype': 'c2_domain',
            'source': 'test_source',
            'type': 'domain',
            'expiration_ts': 1512000062
        }]
        finalize_mock.assert_called_with(expected_intel, None)

    @patch('streamalert.threat_intel_downloader.main.ThreatStream._finalize')
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_next(self, get_mock, finalize_mock):
        """ThreatStream - Connection to ThreatStream.com, with Continuation"""
        next_url = 'this_url'
        get_mock.return_value.json.return_value = self._get_http_response(
            next_url)
        get_mock.return_value.status_code = 200
        self.threatstream._config['ioc_filters'] = {'test_source'}
        self.threatstream._connect('previous_url')
        expected_intel = [{
            'value': 'malicious_domain2.com',
            'itype': 'c2_domain',
            'source': 'test_source',
            'type': 'domain',
            'expiration_ts': 1512000062
        }]
        finalize_mock.assert_called_with(expected_intel, next_url)

    @raises(ThreatStreamRequestsError)
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_unauthed(self, get_mock):
        """ThreatStream - Connection to ThreatStream.com, Unauthorized Error"""
        get_mock.return_value.json.return_value = self._get_http_response()
        get_mock.return_value.status_code = 401
        self.threatstream._connect('previous_url')

    @raises(ThreatStreamRequestsError)
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_retry_error(self, get_mock):
        """ThreatStream - Connection to ThreatStream.com, Retry Error"""
        get_mock.return_value.status_code = 500
        self.threatstream._connect('previous_url')

    @raises(ThreatStreamRequestsError)
    @patch('streamalert.threat_intel_downloader.main.requests.get')
    def test_connect_with_unknown_error(self, get_mock):
        """ThreatStream - Connection to ThreatStream.com, Unknown Error"""
        get_mock.return_value.status_code = 404
        self.threatstream._connect('previous_url')

    @patch(
        'streamalert.threat_intel_downloader.main.ThreatStream._load_api_creds'
    )
    @patch('streamalert.threat_intel_downloader.main.ThreatStream._connect')
    def test_runner(self, connect_mock, _):
        """ThreatStream - Runner"""
        expected_url = (
            '/api/v2/intelligence/?username=user&api_key=key&limit=1000&q='
            '(status="active")+AND+(type="domain"+OR+type="ip"+OR+type="md5")+'
            'AND+NOT+(itype="bot_ip"+OR+itype="brute_ip"+OR+itype="scan_ip"+'
            'OR+itype="spam_ip"+OR+itype="tor_ip")')
        self.threatstream.api_key = 'key'
        self.threatstream.api_user = 'user'
        self.threatstream.runner({'none': 'test'})
        connect_mock.assert_called_with(expected_url)

    @patch(
        'streamalert.threat_intel_downloader.main.ThreatStream._write_to_dynamodb_table'
    )
    @patch(
        'streamalert.threat_intel_downloader.main.ThreatStream._invoke_lambda_function'
    )
    def test_finalize(self, invoke_mock, write_mock):
        """ThreatStream - Finalize with Intel"""
        intel = ['foo', 'bar']
        self.threatstream._finalize(intel, None)
        write_mock.assert_called_with(intel)
        invoke_mock.assert_not_called()

    @patch(
        'streamalert.threat_intel_downloader.main.ThreatStream._write_to_dynamodb_table'
    )
    @patch(
        'streamalert.threat_intel_downloader.main.ThreatStream._invoke_lambda_function'
    )
    def test_finalize_next_url(self, invoke_mock, write_mock):
        """ThreatStream - Finalize with Next URL"""
        intel = ['foo', 'bar']
        self.threatstream._finalize(intel, 'next')
        write_mock.assert_called_with(intel)
        invoke_mock.assert_called_with('next')

    @patch('boto3.resource')
    def test_write_to_dynamodb_table(self, boto_mock):
        """ThreatStream - Write Intel to DynamoDB Table"""
        intel = [self._get_fake_intel('malicious_domain.com', 'test_source')]
        expected_intel = {
            'expiration_ts': '2017-11-30T00:01:02.123Z',
            'source': 'test_source',
            'ioc_type': 'domain',
            'sub_type': 'c2_domain',
            'ioc_value': 'malicious_domain.com'
        }
        self.threatstream._write_to_dynamodb_table(intel)
        batch_writer = boto_mock.return_value.Table.return_value.batch_writer.return_value
        batch_writer.__enter__.return_value.put_item.assert_called_with(
            Item=expected_intel)

    @patch('boto3.resource')
    @raises(ClientError)
    def test_write_to_dynamodb_table_error(self, boto_mock):
        """ThreatStream - Write Intel to DynamoDB Table, Error"""
        intel = [self._get_fake_intel('malicious_domain.com', 'test_source')]
        err = ClientError({'Error': {'Code': 404}}, 'PutItem')
        batch_writer = boto_mock.return_value.Table.return_value.batch_writer.return_value
        batch_writer.__enter__.return_value.put_item.side_effect = err

        self.threatstream._write_to_dynamodb_table(intel)

    @patch('boto3.client')
    def test_invoke_lambda_function(self, boto_mock):
        """ThreatStream - Invoke Lambda Function"""
        boto_mock.return_value = MockLambdaClient()
        self.threatstream._invoke_lambda_function('next_token')
        boto_mock.assert_called_once()

    @patch('boto3.client', Mock(return_value=MockLambdaClient()))
    @raises(ThreatStreamLambdaInvokeError)
    def test_invoke_lambda_function_error(self):
        """ThreatStream - Invoke Lambda Function, Error"""
        MockLambdaClient._raise_exception = True
        self.threatstream._invoke_lambda_function('next_token')
Example #26
 def import_publishers(cls):
     if not cls._is_imported:
         config = load_config()
         import_folders(
             *config['global']['general'].get('publisher_locations', []))
         cls._is_imported = True
Example #27
class TestAthenaPartitioner:
    """Test class for AthenaPartitioner when output data in Parquet format"""
    @patch('streamalert.athena_partitioner.main.load_config',
           Mock(return_value=load_config('tests/unit/conf/')))
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    @patch('streamalert.shared.athena.boto3')
    def setup(self, boto_patch):
        """Setup the AthenaPartitioner tests"""
        boto_patch.client.return_value = MockAthenaClient()
        self._partitioner = AthenaPartitioner()

    def test_add_partitions(self):
        """AthenaPartitioner - Add Partitions"""
        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {
                b'parquet/alerts/dt=2017-08-27-14/rule_name_alerts-1304134918401.parquet',
                b'parquet/alerts/dt=2020-02-13-08/prefix_streamalert_alert_delivery-01-abcd.parquet'
            },
            'unit-test-streamalert-data': {
                b'log_type_1/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/15/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/16/test-data-11111-22222-33333.snappy',
                b'log_type_3/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_1/2017/08/26/11/test-data-11111-22222-33333.snappy'
            },
            'test-bucket-with-data': {
                b'dt=2020-02-12-05/log_type_1_01234.parquet',
                b'dt=2020-02-12-06/log_type_1_abcd.parquet',
                b'dt=2020-02-12-06/log_type_2_0123.parquet',
                b'dt=2020-02-12-07/log_type_2_abcd.parquet'
            }
        }
        result = self._partitioner._add_partitions()

        assert_true(result)

    @patch('logging.Logger.warning')
    def test_add_partitions_none(self, log_mock):
        """AthenaPartitioner - Add Partitions, None to Add"""
        result = self._partitioner._add_partitions()
        log_mock.assert_called_with('No partitions to add')
        assert_equal(result, False)

    def test_get_partitions_from_keys_parquet(self):
        """AthenaPartitioner - Get Partitions From Keys in parquet format"""
        expected_result = {
            'alerts': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalerts/'
                 'parquet/alerts/dt=2017-08-26-14\''),
                '(dt = \'2017-08-27-14\')':
                ('\'s3://unit-test-streamalerts/'
                 'parquet/alerts/dt=2017-08-27-14\''),
                '(dt = \'2017-08-26-15\')':
                ('\'s3://unit-test-streamalerts/'
                 'parquet/alerts/dt=2017-08-26-15\'')
            },
            'log_type_1': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalert-data/'
                 'parquet/log_type_1/dt=2017-08-26-14\'')
            },
            'log_type_2': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalert-data/'
                 'parquet/log_type_2/dt=2017-08-26-14\''),
                '(dt = \'2017-08-26-15\')':
                ('\'s3://unit-test-streamalert-data/'
                 'parquet/log_type_2/dt=2017-08-26-15\''),
                '(dt = \'2017-08-26-16\')':
                ('\'s3://unit-test-streamalert-data/'
                 'parquet/log_type_2/dt=2017-08-26-16\''),
            },
            'log_type_3': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalert-data/'
                 'parquet/log_type_3/dt=2017-08-26-14\''),
            }
        }

        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {
                b'parquet/alerts/dt=2017-08-26-14/rule_name_alerts-1304134918401.parquet',
                b'parquet/alerts/dt=2017-08-27-14/rule_name_alerts-1304134918401.parquet',
                b'parquet/alerts/dt=2017-08-26-15/rule_name_alerts-1304134918401.parquet'
            },
            'unit-test-streamalert-data': {
                b'parquet/log_type_1/dt=2017-08-26-14/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_2/dt=2017-08-26-14/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_2/dt=2017-08-26-14/test-data-11111-22222-33334.snappy',
                b'parquet/log_type_2/dt=2017-08-26-15/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_2/dt=2017-08-26-16/test-data-11111-22222-33333.snappy',
                b'parquet/log_type_3/dt=2017-08-26-14/test-data-11111-22222-33333.snappy',
            },
            'test-bucket-with-data': {
                b'dt=2017-08-26-14/rule_name_alerts-1304134918401.parquet',
                b'dt=2017-07-30-14/rule_name_alerts-1304134918401.parquet'
            }
        }

        result = self._partitioner._get_partitions_from_keys()

        assert_equal(result, expected_result)

    @patch('logging.Logger.warning')
    def test_get_partitions_from_keys_error(self, log_mock):
        """AthenaPartitioner - Get Partitions From Keys, Bad Key"""
        bad_key = b'bad_match_string'
        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {bad_key}
        }

        result = self._partitioner._get_partitions_from_keys()

        log_mock.assert_called_with(
            'The key %s does not match any regex, skipping',
            bad_key.decode('utf-8'))
        assert_equal(result, dict())

    @staticmethod
    def _s3_record(count):
        return {
            'Records': [{
                's3': {
                    'bucket': {
                        'name': 'unit-test-streamalerts'
                    },
                    'object': {
                        'key': ('parquet/alerts/dt=2017-08-{:02d}-'
                                '14/02/test.json'.format(val + 1))
                    }
                }
            } for val in range(count)]
        }

    @staticmethod
    def _s3_record_placeholder_file():
        return {
            'Records': [{
                's3': {
                    'bucket': {
                        'name': 'unit-test-streamalerts'
                    },
                    'object': {
                        'key':
                        'parquet/alerts/dt=2017-08-01-14/02/test.json_$folder$'
                    }
                }
            }]
        }

    @staticmethod
    def _create_test_message(count=2, placeholder=False):
        """Helper function for creating an SQS message body"""
        if placeholder:
            body = json.dumps(
                TestAthenaPartitioner._s3_record_placeholder_file())
        else:
            count = min(count, 30)
            body = json.dumps(TestAthenaPartitioner._s3_record(count))
        return {
            'Records': [{
                'body': body,
                'messageId': "40d4fac0-64a1-4a20-8be4-893c51aebca1",
                "attributes": {
                    "SentTimestamp": "1534284301036"
                }
            }]
        }

    @patch('logging.Logger.debug')
    @patch(
        'streamalert.athena_partitioner.main.AthenaPartitioner._add_partitions'
    )
    def test_run(self, add_mock, log_mock):
        """AthenaPartitioner - Run"""
        add_mock.return_value = True
        self._partitioner.run(self._create_test_message(1))
        log_mock.assert_called_with(
            'Received notification for object \'%s\' in bucket \'%s\'',
            'parquet/alerts/dt=2017-08-01-14/02/test.json'.encode(),
            'unit-test-streamalerts')

    @patch('logging.Logger.info')
    def test_run_placeholder_file(self, log_mock):
        """AthenaPartitioner - Run, Placeholder File"""
        self._partitioner.run(self._create_test_message(1, True))
        log_mock.assert_has_calls([
            call('Skipping placeholder file notification with key: %s',
                 b'parquet/alerts/dt=2017-08-01-14/02/test.json_$folder$')
        ])

    @patch('logging.Logger.warning')
    def test_run_no_messages(self, log_mock):
        """AthenaPartitioner - Run, No Messages"""
        self._partitioner.run(self._create_test_message(0))
        log_mock.assert_called_with('No partitions to add')

    @patch('logging.Logger.error')
    def test_run_invalid_bucket(self, log_mock):
        """AthenaPartitioner - Run, Bad Bucket Name"""
        event = self._create_test_message(0)
        bucket = 'bad.bucket.name'
        s3_record = self._s3_record(1)
        s3_record['Records'][0]['s3']['bucket']['name'] = bucket
        event['Records'][0]['body'] = json.dumps(s3_record)
        self._partitioner.run(event)
        log_mock.assert_called_with(
            '\'%s\' not found in \'buckets\' config. Please add this '
            'bucket to enable additions of Hive partitions.', bucket)
Example #28
class TestAthenaPartitionerJSON:
    """Test class for AthenaPartitioner when output data in JSON format"""
    @patch('streamalert.athena_partitioner.main.load_config',
           Mock(return_value=load_config('tests/unit/conf_athena/')))
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-east-1'})
    @patch('streamalert.shared.athena.boto3')
    def setup(self, boto_patch):
        """Setup the AthenaPartitioner tests"""
        boto_patch.client.return_value = MockAthenaClient()
        self._partitioner = AthenaPartitioner()

    def test_get_partitions_from_keys_json(self):
        """AthenaPartitioner - Get Partitions From Keys in json format"""
        expected_result = {
            'alerts': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalerts/'
                 'parquet/alerts/dt=2017-08-26-14\''),
                '(dt = \'2017-08-27-14\')':
                ('\'s3://unit-test-streamalerts/'
                 'parquet/alerts/dt=2017-08-27-14\''),
                '(dt = \'2017-08-26-15\')': ('\'s3://unit-test-streamalerts/'
                                             'alerts/2017/08/26/15\'')
            },
            'log_type_1': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalert-data/'
                 'log_type_1/2017/08/26/14\'')
            },
            'log_type_2': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalert-data/'
                 'log_type_2/2017/08/26/14\''),
                '(dt = \'2017-08-26-15\')':
                ('\'s3://unit-test-streamalert-data/'
                 'log_type_2/2017/08/26/15\''),
                '(dt = \'2017-08-26-16\')':
                ('\'s3://unit-test-streamalert-data/'
                 'log_type_2/2017/08/26/16\''),
            },
            'log_type_3': {
                '(dt = \'2017-08-26-14\')':
                ('\'s3://unit-test-streamalert-data/'
                 'log_type_3/2017/08/26/14\''),
            }
        }

        self._partitioner._s3_buckets_and_keys = {
            'unit-test-streamalerts': {
                b'parquet/alerts/dt=2017-08-26-14/rule_name_alerts-1304134918401.json',
                b'parquet/alerts/dt=2017-08-27-14/rule_name_alerts-1304134918401.json',
                b'alerts/2017/08/26/15/rule_name_alerts-1304134918401.json'
            },
            'unit-test-streamalert-data': {
                b'log_type_1/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/14/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/14/test-data-11111-22222-33334.snappy',
                b'log_type_2/2017/08/26/15/test-data-11111-22222-33333.snappy',
                b'log_type_2/2017/08/26/16/test-data-11111-22222-33333.snappy',
                b'log_type_3/2017/08/26/14/test-data-11111-22222-33333.snappy',
            },
            'test-bucket-with-data': {
                b'2017/08/26/14/rule_name_alerts-1304134918401.json',
                b'2017/07/30/14/rule_name_alerts-1304134918401.json'
            }
        }

        result = self._partitioner._get_partitions_from_keys()

        assert_equal(result, expected_result)
Example #29
 def test_load_schemas():
     """Shared - Config Loading - Schemas"""
     # Load from separate dir where logs.json doesn't exist
     config = load_config(conf_dir='conf_schemas')
     basic_config = basic_streamalert_config()
     assert_equal(config['logs'], basic_config['logs'])
Example #30
 def test_load_schemas_logs():
     """Shared - Config Loading - Schemas and Logs.json Exist"""
     # Check that data is loaded from conf/logs.json rather than the schemas dir when both exist
     config = load_config(conf_dir='conf')
     # Logs.json is preferred over schemas for backwards compatibility.
     assert_equal(config['logs'], {})
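
Examples #11, #29, and #30 cover the schemas directory: log schemas can live either in conf/logs.json or in a schemas/ sub-directory, and when both exist logs.json wins for backwards compatibility (the empty dict asserted above presumably comes from that fixture's empty logs.json). A rough sketch of that precedence is shown below, assuming flat JSON schema files; the real loader may differ in detail.

import json
import os

def load_log_schemas_sketch(conf_dir):
    """Illustrative only: prefer conf/logs.json over conf/schemas/*.json when both exist."""
    logs_file = os.path.join(conf_dir, 'logs.json')
    if os.path.exists(logs_file):
        with open(logs_file) as handle:
            return json.load(handle)  # logs.json takes precedence for backwards compatibility

    schemas = {}
    schema_dir = os.path.join(conf_dir, 'schemas')
    for name in sorted(os.listdir(schema_dir)):
        if name.endswith('.json'):
            with open(os.path.join(schema_dir, name)) as handle:
                schemas.update(json.load(handle))
    return schemas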