Code example #1
File: test_athena.py  Project: zachzeid/streamalert
    def setup(self):
        """Setup the AthenaClient tests"""

        self._db_name = 'test_database'
        config = load_config('tests/unit/conf/')
        prefix = config['global']['account']['prefix']

        self.client = AthenaClient(
            self._db_name, 's3://{}-streamalert-athena-results'.format(prefix),
            'unit-test')
Code example #2
def get_athena_client(config):
    """Get an athena client using the current config settings

    Args:
        config (CLIConfig): Loaded StreamAlert config

    Returns:
        AthenaClient: instantiated client for performing athena actions
    """
    prefix = config['global']['account']['prefix']
    athena_config = config['lambda']['athena_partitioner_config']

    db_name = get_database_name(config)

    # Get the S3 bucket to store Athena query results
    results_bucket = athena_config.get(
        'results_bucket',
        's3://{}-streamalert-athena-results'.format(prefix)
    )

    return AthenaClient(
        db_name,
        results_bucket,
        'streamalert_cli',
        region=config['global']['account']['region']
    )
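
A brief usage sketch for the helper above; how the CLIConfig is obtained and the query itself are illustrative assumptions, not taken from the project:

# Hedged usage sketch: 'config' is assumed to be a loaded StreamAlert CLIConfig
athena_client = get_athena_client(config)

# check_database_exists() and run_async_query() are AthenaClient methods
# exercised by the tests later on this page; the SQL string is an example only
if athena_client.check_database_exists():
    athena_client.run_async_query('SHOW TABLES')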
Code example #3
    @classmethod
    def _create_client(cls, db_name, results_bucket):
        if cls._ATHENA_CLIENT:
            return  # Client already created/cached

        cls._ATHENA_CLIENT = AthenaClient(db_name, results_bucket,
                                          cls.ATHENA_S3_PREFIX)

        # Check if the database exists when the client is created
        if not cls._ATHENA_CLIENT.check_database_exists():
            raise AthenaPartitionerError(
                'The \'{}\' database does not exist'.format(db_name))
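
The early return above relies on a class-level cache, so only the first call per process actually builds a client. A minimal self-contained sketch of that pattern (the class name, prefix value, and stand-in tuple are illustrative assumptions):

class _PartitionerSketch:
    """Toy illustration of the class-level client cache used above"""
    _ATHENA_CLIENT = None
    ATHENA_S3_PREFIX = 'athena_partitioner'  # assumed value

    @classmethod
    def _create_client(cls, db_name, results_bucket):
        if cls._ATHENA_CLIENT:
            return  # already created/cached for this process
        # The real code builds an AthenaClient and verifies the database
        # exists; a plain tuple stands in here to keep the sketch runnable
        cls._ATHENA_CLIENT = (db_name, results_bucket, cls.ATHENA_S3_PREFIX)

# Only the first call populates the cache; the second returns immediately
_PartitionerSketch._create_client('db_one', 's3://bucket-one')
_PartitionerSketch._create_client('db_two', 's3://bucket-two')
assert _PartitionerSketch._ATHENA_CLIENT == ('db_one', 's3://bucket-one', 'athena_partitioner')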
Code example #4
File: test_athena.py  Project: zachzeid/streamalert
    @patch('streamalert.shared.athena.datetime')
    def test_init_fix_bucket_path(self, date_mock):
        """Athena - Fix Bucket Path"""
        date_now = datetime.utcnow()
        date_mock.utcnow.return_value = date_now
        date_format = date_now.strftime('%Y/%m/%d/%H')
        expected_path = 's3://test-streamalert-athena-results/unit-test/{}'.format(
            date_format)
        with patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'}):
            client = AthenaClient(self._db_name,
                                  'test-streamalert-athena-results',
                                  'unit-test')
            assert_equal(client.results_path, expected_path)
Code example #5
File: promoter.py  Project: zachzeid/streamalert
    def __init__(self):
        self._config = load_config()
        prefix = self._config['global']['account']['prefix']

        # Create the rule table class for getting staging information
        self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

        athena_config = self._config['lambda']['athena_partitioner_config']

        # Get the name of the athena database to access
        db_name = athena_config.get('database_name',
                                    get_database_name(self._config))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}-streamalert-athena-results'.format(prefix))

        self._athena_client = AthenaClient(db_name, results_bucket,
                                           self.ATHENA_S3_PREFIX)
        self._current_time = datetime.utcnow()
        self._staging_stats = dict()
Code example #6
class RulePromoter:
    """Run queries to generate statistics on alerts."""

    ATHENA_S3_PREFIX = 'rule_promoter'
    STREAMALERT_DATABASE = '{}_streamalert'

    def __init__(self):
        self._config = load_config()
        prefix = self._config['global']['account']['prefix']

        # Create the rule table class for getting staging information
        self._rule_table = RuleTable('{}_streamalert_rules'.format(prefix))

        athena_config = self._config['lambda'][
            'athena_partition_refresh_config']

        # Get the name of the athena database to access
        db_name = athena_config.get('database_name',
                                    self.STREAMALERT_DATABASE.format(prefix))

        # Get the S3 bucket to store Athena query results
        results_bucket = athena_config.get(
            'results_bucket',
            's3://{}-streamalert-athena-results'.format(prefix))

        self._athena_client = AthenaClient(db_name, results_bucket,
                                           self.ATHENA_S3_PREFIX)
        self._current_time = datetime.utcnow()
        self._staging_stats = dict()

    def _get_staging_info(self):
        """Query the Rule table for rule staging info needed to count each rule's alerts

        Example of rule metadata returned by RuleTable.remote_rule_info():
        {
            'example_rule_name':
                {
                    'Staged': True,
                    'StagedAt': datetime.datetime object,
                    'StagedUntil': '2018-04-21T02:23:13.332223Z'
                }
        }
        """
        for rule in sorted(self._rule_table.remote_rule_info):
            info = self._rule_table.remote_rule_info[rule]
            # If the rule is not staged, do not get stats on it
            if not info['Staged']:
                continue

            self._staging_stats[rule] = StagingStatistic(
                info['StagedAt'], info['StagedUntil'], self._current_time,
                rule)

        return len(self._staging_stats) != 0

    def _update_alert_count(self):
        """Update the alert count for each currently staged rule

        Runs a single compound Athena query covering all staged rules and
        stores each resulting alert count (int) on the corresponding
        StagingStatistic entry in self._staging_stats.
        """
        query = StagingStatistic.construct_compound_count_query(
            list(self._staging_stats.values()))
        LOGGER.debug('Running compound query for alert count: \'%s\'', query)
        for page, results in enumerate(
                self._athena_client.query_result_paginator(query)):
            for i, row in enumerate(results['ResultSet']['Rows']):
                if page == 0 and i == 0:  # skip header row included in first page only
                    continue

                row_values = [list(data.values())[0] for data in row['Data']]
                rule_name, alert_count = row_values[0], int(row_values[1])

                LOGGER.debug('Found %d alerts for rule \'%s\'', alert_count,
                             rule_name)

                self._staging_stats[rule_name].alert_count = alert_count

    def run(self, send_digest):
        """Perform statistic analysis of currently staged rules

        Args:
            send_digest (bool): True if the staging statistics digest should be
                published, False otherwise
        """
        if not self._get_staging_info():
            LOGGER.debug('No staged rules to promote')
            return

        self._update_alert_count()

        self._promote_rules()

        if send_digest:
            publisher = StatsPublisher(self._config, self._athena_client,
                                       self._current_time)
            publisher.publish(list(self._staging_stats.values()))
        else:
            LOGGER.debug('Staging statistics digest will not be sent')

    def _promote_rules(self):
        """Promote any rule that has not resulted in any alerts since being staged"""
        for rule in self._rules_to_be_promoted:
            LOGGER.info('Promoting rule \'%s\' at %s', rule,
                        self._current_time)
            self._rule_table.toggle_staged_state(rule, False)

    @property
    def _rules_to_be_promoted(self):
        """Returns a list of rules that are eligible for promotion"""
        return [
            rule for rule, stat in self._staging_stats.items()
            if self._current_time > stat.staged_until and stat.alert_count == 0
        ]

    @property
    def _rules_failing_promotion(self):
        """Returns a list of rules that are ineligible for promotion"""
        return [
            rule for rule, stat in self._staging_stats.items()
            if stat.alert_count != 0
        ]
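
A short usage sketch for the promoter; the Lambda-style handler and the event key are assumptions, only RulePromoter and its run(send_digest) method come from the class above:

# Hedged usage sketch: a scheduled Lambda-style entry point is assumed here
def handler(event, _context):
    # 'send_digest' is the flag documented on RulePromoter.run(); reading it
    # from the invocation event is an assumption for illustration
    RulePromoter().run(send_digest=bool(event.get('send_digest')))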
Code example #7
File: test_athena.py  Project: zachzeid/streamalert
class TestAthenaClient:
    """Test class for AthenaClient"""
    @patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'})
    @patch('boto3.client',
           Mock(side_effect=lambda c, config=None: MockAthenaClient()))
    def setup(self):
        """Setup the AthenaClient tests"""

        self._db_name = 'test_database'
        config = load_config('tests/unit/conf/')
        prefix = config['global']['account']['prefix']

        self.client = AthenaClient(
            self._db_name, 's3://{}-streamalert-athena-results'.format(prefix),
            'unit-test')

    @patch('streamalert.shared.athena.datetime')
    def test_init_fix_bucket_path(self, date_mock):
        """Athena - Fix Bucket Path"""
        date_now = datetime.utcnow()
        date_mock.utcnow.return_value = date_now
        date_format = date_now.strftime('%Y/%m/%d/%H')
        expected_path = 's3://test-streamalert-athena-results/unit-test/{}'.format(
            date_format)
        with patch.dict(os.environ, {'AWS_DEFAULT_REGION': 'us-west-1'}):
            client = AthenaClient(self._db_name,
                                  'test-streamalert-athena-results',
                                  'unit-test')
            assert_equal(client.results_path, expected_path)

    def test_unique_values_from_query(self):
        """Athena - Unique Values from Query"""
        query = {
            'ResultSet': {
                'Rows': [
                    {
                        'Data': [{
                            'VarCharValue': 'foobar'
                        }]
                    },
                    {
                        'Data': [{
                            'VarCharValue': 'barfoo'
                        }]
                    },
                    {
                        'Data': [{
                            'VarCharValue': 'barfoo'
                        }]
                    },
                    {
                        'Data': [{
                            'VarCharValue': 'foobarbaz'
                        }]
                    },
                ]
            }
        }
        expected_result = {'foobar', 'barfoo', 'foobarbaz'}

        result = self.client._unique_values_from_query(query)
        assert_count_equal(result, expected_result)

    def test_check_database_exists(self):
        """Athena - Check Database Exists"""
        self.client._client.results = [{
            'Data': [{
                'VarCharValue': self._db_name
            }]
        }]

        assert_true(self.client.check_database_exists())

    def test_check_database_exists_invalid(self):
        """Athena - Check Database Exists - Does Not Exist"""
        self.client._client.results = None

        assert_false(self.client.check_database_exists())

    def test_check_table_exists(self):
        """Athena - Check Table Exists"""
        self.client._client.results = [{
            'Data': [{
                'VarCharValue': 'test_table'
            }]
        }]

        assert_true(self.client.check_table_exists('test_table'))

    def test_check_table_exists_invalid(self):
        """Athena - Check Table Exists - Does Not Exist"""
        self.client._client.results = None

        assert_false(self.client.check_table_exists('test_table'))

    def test_get_table_partitions(self):
        """Athena - Get Table Partitions"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-10-10'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-09-10'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-09-10'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'dt=2018-12-11-10'
                }]
            },
        ]

        expected_result = {
            'dt=2018-12-10-10', 'dt=2018-12-09-10', 'dt=2018-12-11-10'
        }

        result = self.client.get_table_partitions('test_table')
        assert_count_equal(result, expected_result)

    def test_get_table_partitions_error(self):
        """Athena - Get Table Partitions, Exception"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError,
                      self.client.get_table_partitions, 'test_table')

    def test_drop_table(self):
        """Athena - Drop Table, Success"""
        assert_true(self.client.drop_table('test_table'))

    def test_drop_table_failure(self):
        """Athena - Drop Table, Failure"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.drop_table,
                      'test_table')

    @patch('streamalert.shared.athena.AthenaClient.drop_table')
    def test_drop_all_tables(self, drop_table_mock):
        """Athena - Drop All Tables, Success"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'table_01'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_02'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_02'
                }]
            },
        ]
        assert_true(self.client.drop_all_tables())
        assert_equal(drop_table_mock.call_count, 2)

    @patch('streamalert.shared.athena.AthenaClient.drop_table')
    def test_drop_all_tables_failure(self, drop_table_mock):
        """Athena - Drop All Tables, Failure"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'table_01'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_02'
                }]
            },
            {
                'Data': [{
                    'VarCharValue': 'table_03'
                }]
            },
        ]
        drop_table_mock.side_effect = [True, True, False]
        assert_false(self.client.drop_all_tables())

    def test_drop_all_tables_exception(self):
        """Athena - Drop All Tables, Exception"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.drop_all_tables)

    def test_execute_query(self):
        """Athena - Execute Query"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client._execute_query,
                      'BAD SQL')

    def test_execute_and_wait(self):
        """Athena - Execute and Wait"""
        self.client._client.results = [
            {
                'Data': [{
                    'VarCharValue': 'result'
                }]
            },
        ]
        result = self.client._execute_and_wait('SQL query')
        assert_true(result in self.client._client.query_executions)

    def test_execute_and_wait_failed(self):
        """Athena - Execute and Wait, Failed"""
        query = 'SQL query'
        self.client._client.result_state = 'FAILED'
        assert_raises(AthenaQueryExecutionError, self.client._execute_and_wait,
                      query)

    def test_query_result_paginator(self):
        """Athena - Query Result Paginator"""
        data = {'Data': [{'VarCharValue': 'result'}]}
        self.client._client.results = [
            data,
        ]

        items = list(self.client.query_result_paginator('test query'))
        assert_count_equal(items, [{'ResultSet': {'Rows': [data]}}] * 4)

    @raises(AthenaQueryExecutionError)
    def test_query_result_paginator_error(self):
        """Athena - Query Result Paginator, Exception"""
        self.client._client.raise_exception = True
        list(self.client.query_result_paginator('test query'))

    def test_run_async_query(self):
        """Athena - Run Async Query, Success"""
        assert_true(self.client.run_async_query('test query'))

    def test_run_async_query_failure(self):
        """Athena - Run Async Query, Failure"""
        self.client._client.raise_exception = True
        assert_raises(AthenaQueryExecutionError, self.client.run_async_query,
                      'test query')
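
For context, a sketch of the imports the test class above presumably relies on; the streamalert.shared.athena path matches the patch targets used in the tests, but the remaining module paths and the home of MockAthenaClient are assumptions:

# Assumed import block for the test module above (only best guesses beyond
# streamalert.shared.athena, which appears in the patch targets)
import os
from datetime import datetime

from mock import Mock, patch  # or unittest.mock
from nose.tools import (assert_count_equal, assert_equal, assert_false,
                        assert_raises, assert_true, raises)

from streamalert.shared.athena import AthenaClient, AthenaQueryExecutionError  # error class module assumed
from streamalert.shared.config import load_config   # assumed location of load_config
# MockAthenaClient is a helper from the project's test suite (exact module not shown here)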