Example #1
    def test_threat_intel_match(self, mock_client):
        """Rules Engine - Threat Intel is enabled when threat_intel_match is called"""
        @rule(datatypes=['sourceAddress', 'destinationDomain', 'fileHash'],
              outputs=['s3:sample_bucket'])
        def match_rule(_): # pylint: disable=unused-variable
            """Testing dummy rule"""
            return True

        mock_client.return_value = MockDynamoDBClient()
        toggled_config = self.config
        toggled_config['global']['threat_intel']['enabled'] = True
        toggled_config['global']['threat_intel']['dynamodb_table'] = 'test_table_name'

        new_rules_engine = RulesEngine(toggled_config)
        records = mock_normalized_records()
        alerts = new_rules_engine.threat_intel_match(records)
        assert_equal(len(alerts), 2)
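The threat intel settings toggled above live under the 'global' section of the loaded config. A minimal sketch of just that section is shown below; only 'enabled' and 'dynamodb_table' appear in the test, and everything else in the real config is omitted here.

# Minimal sketch of the config section the test toggles. Only 'enabled' and
# 'dynamodb_table' come from the test; the rest of the real config is omitted.
toggled_config = {
    'global': {
        'threat_intel': {
            'enabled': True,                      # turn threat intel matching on
            'dynamodb_table': 'test_table_name'   # DynamoDB table holding IOC values
        }
    }
}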
Example #2
    def test_process_allow_multi_round_normalization(self, mock_client):
        """Rules Engine - Threat Intel is enabled, run multi-round normalization"""
        @rule(datatypes=['fileHash'], outputs=['s3:sample_bucket'])
        def match_file_hash(rec):  # pylint: disable=unused-variable
            """Testing dummy rule to match file hash"""
            return 'streamalert:ioc' in rec and 'md5' in rec['streamalert:ioc']

        @rule(datatypes=['fileHash'], outputs=['s3:sample_bucket'])
        def match_file_hash_again(_):  # pylint: disable=unused-variable
            """Testing dummy rule to match file hash again"""
            return False

        @rule(datatypes=['fileHash', 'sourceDomain'],
              outputs=['s3:sample_bucket'])
        def match_source_domain(rec):  # pylint: disable=unused-variable
            """Testing dummy rule to match source domain and file hash"""
            return 'streamalert:ioc' in rec

        mock_client.return_value = MockDynamoDBClient()
        toggled_config = self.config
        toggled_config['global']['threat_intel']['enabled'] = True
        toggled_config['global']['threat_intel'][
            'dynamodb_table'] = 'test_table_name'

        new_rules_engine = RulesEngine(toggled_config)
        kinesis_data = {
            "Field1": {
                "SubField1": {
                    "key1": 17,
                    "key2_md5": "md5-of-file",
                    "key3_source_domain": "evil.com"
                },
                "SubField2": 1
            },
            "Field2": {
                "Authentication": {}
            },
            "Field3": {},
            "Field4": {}
        }

        kinesis_data = json.dumps(kinesis_data)
        service, entity = 'kinesis', 'test_stream_threat_intel'
        raw_record = make_kinesis_raw_record(entity, kinesis_data)
        payload = load_and_classify_payload(toggled_config, service, entity,
                                            raw_record)
        alerts, normalized_records = new_rules_engine.run(payload)

        # Two of the testing rules match on threat intelligence results, so no alerts
        # will be generated before threat intel takes effect.
        assert_equal(len(alerts), 0)

        # The record will be normalized only once, even though two different rules
        # request different normalization keys.
        assert_equal(len(normalized_records), 1)
        assert_equal(
            normalized_records[0].
            pre_parsed_record['streamalert:normalization'].keys(),
            ['fileHash', 'sourceDomain'])

        # Pass normalized records to threat intel engine.
        alerts_from_threat_intel = new_rules_engine.threat_intel_match(
            normalized_records)
        assert_equal(len(alerts_from_threat_intel), 2)
        assert_equal(alerts_from_threat_intel[0].rule_name, 'match_file_hash')
        assert_equal(alerts_from_threat_intel[1].rule_name,
                     'match_source_domain')
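Based on the assertions above, the single normalized record carries a 'streamalert:normalization' map keyed by datatype, and the threat intel step is expected to add a 'streamalert:ioc' entry that the matching rules check for. A rough sketch of those annotations follows; only the key names come from the test, and the value shapes are assumptions.

# Rough sketch of the annotations the test asserts on. Only the key names
# ('streamalert:normalization', 'fileHash', 'sourceDomain', 'streamalert:ioc',
# 'md5') come from the test; the value shapes below are assumptions.
annotated_record = {
    'streamalert:normalization': {
        # Assumed shape: paths to where each normalized value was found in the record
        'fileHash': [['Field1', 'SubField1', 'key2_md5']],
        'sourceDomain': [['Field1', 'SubField1', 'key3_source_domain']]
    },
    # Added during threat intel matching when a value matches a known IOC:
    'streamalert:ioc': {
        'md5': 'md5-of-file'  # assumed shape; match_file_hash only checks for the 'md5' key
    }
}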
Example #3
class StreamAlert(object):
    """Wrapper class for handling StreamAlert classification and processing"""
    config = {}

    def __init__(self, context):
        """Initializer

        Args:
            context (dict): An AWS context object which provides metadata on the currently
                executing lambda function.
        """
        # Load the config. Validation occurs during load, which will
        # raise exceptions on any ConfigErrors
        StreamAlert.config = StreamAlert.config or load_config()

        # Load the environment from the context arn
        self.env = load_env(context)

        # Instantiate the alert forwarder here to handle sending triggered alerts to
        # the alert processor
        self.alert_forwarder = AlertForwarder()

        # Instantiate a classifier that is used for this run
        self.classifier = StreamClassifier(config=self.config)

        self._failed_record_count = 0
        self._processed_record_count = 0
        self._processed_size = 0
        self._alerts = []

        rule_import_paths = [
            item for location in {'rule_locations', 'matcher_locations'}
            for item in self.config['global']['general'][location]
        ]

        # Create an instance of the RulesEngine class that gets cached in the
        # StreamAlert class as an instance property
        self._rules_engine = RulesEngine(self.config, *rule_import_paths)

        # Firehose client attribute
        self._firehose_client = None

    def run(self, event):
        """StreamAlert Lambda function handler.

        Loads the configuration for the StreamAlert function, which contains
        available data sources, log schemas, normalized types, and outputs.
        Classifies incoming logs into parsed types and matches the resulting
        records against rules.

        Args:
            event (dict): An AWS event mapped to a specific source/entity
                containing data read by Lambda.

        Returns:
            bool: True if all logs being parsed match a schema
        """
        records = event.get('Records', [])
        LOGGER.debug('Number of incoming records: %d', len(records))
        if not records:
            return False

        firehose_config = self.config['global'].get('infrastructure',
                                                    {}).get('firehose', {})
        if firehose_config.get('enabled'):
            self._firehose_client = StreamAlertFirehose(
                self.env['lambda_region'], firehose_config,
                self.config['logs'])

        payload_with_normalized_records = []
        for raw_record in records:
            # Get the service and entity from the payload. If the service/entity
            # is not in our config, log an error and move on to the next record
            service, entity = self.classifier.extract_service_and_entity(
                raw_record)
            if not service:
                LOGGER.error(
                    'No valid service found in payload\'s raw record. Skipping '
                    'record: %s', raw_record)
                continue

            if not entity:
                LOGGER.error(
                    'Unable to extract entity from payload\'s raw record for service %s. '
                    'Skipping record: %s', service, raw_record)
                continue

            # Cache the log sources for this service and entity on the classifier
            if not self.classifier.load_sources(service, entity):
                continue

            # Create the StreamPayload to use for encapsulating parsed info
            payload = load_stream_payload(service, entity, raw_record)
            if not payload:
                continue

            payload_with_normalized_records.extend(
                self._process_alerts(payload))

        # Log normalized records metric
        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.NORMALIZED_RECORDS,
                                len(payload_with_normalized_records))

        # Apply Threat Intel to normalized records at the end of the Rule Processor invocation
        record_alerts = self._rules_engine.threat_intel_match(
            payload_with_normalized_records)
        self._alerts.extend(record_alerts)
        if record_alerts:
            self.alert_forwarder.send_alerts(record_alerts)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TOTAL_RECORDS,
                                self._processed_record_count)

        MetricLogger.log_metric(FUNCTION_NAME,
                                MetricLogger.TOTAL_PROCESSED_SIZE,
                                self._processed_size)

        LOGGER.debug('Invalid record count: %d', self._failed_record_count)

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.FAILED_PARSES,
                                self._failed_record_count)

        LOGGER.debug('%s alerts triggered', len(self._alerts))

        MetricLogger.log_metric(FUNCTION_NAME, MetricLogger.TRIGGERED_ALERTS,
                                len(self._alerts))

        # Check if debug logging is on before JSON dumping alerts, since
        # this can be time consuming if there are a lot of alerts
        if self._alerts and LOGGER.isEnabledFor(LOG_LEVEL_DEBUG):
            LOGGER.debug(
                'Alerts:\n%s',
                json.dumps([alert.output_dict() for alert in self._alerts],
                           indent=2,
                           sort_keys=True))

        if self._firehose_client:
            self._firehose_client.send()

        # Only log rule info here if this is not running tests
        # During testing, this gets logged at the end and printing here could be confusing
        # since stress testing calls this method multiple times
        if self.env['lambda_alias'] != 'development':
            stats.print_rule_stats(True)

        return self._failed_record_count == 0

    @property
    def alerts(self):
        """Returns list of Alert instances (useful for testing)."""
        return self._alerts

    def _process_alerts(self, payload):
        """Run the record through the rules, saving any alerts and forwarding them to Dynamo.

        Args:
            payload (StreamPayload): StreamAlert payload object being processed
        """
        payload_with_normalized_records = []
        for record in payload.pre_parse():
            # Increment the processed size using the length of this record
            self._processed_size += len(record.pre_parsed_record)
            self.classifier.classify_record(record)
            if not record.valid:
                if self.env['lambda_alias'] != 'development':
                    LOGGER.error(
                        'Record does not match any defined schemas: %s\n%s',
                        record, record.pre_parsed_record)

                self._failed_record_count += 1
                continue

            # Increment the total processed records to get an accurate assessment of throughput
            self._processed_record_count += len(record.records)

            LOGGER.debug(
                'Classified and Parsed Payload: <Valid: %s, Log Source: %s, Entity: %s>',
                record.valid, record.log_source, record.entity)

            record_alerts, normalized_records = self._rules_engine.run(record)

            payload_with_normalized_records.extend(normalized_records)

            LOGGER.debug(
                'Processed %d valid record(s) that resulted in %d alert(s).',
                len(payload.records), len(record_alerts))

            # Add all parsed records to the categorized payload dict only if Firehose is enabled
            if self._firehose_client:
                # Only send payloads with enabled log sources
                if self._firehose_client.enabled_log_source(
                        payload.log_source):
                    self._firehose_client.categorized_payloads[
                        payload.log_source].extend(payload.records)

            if not record_alerts:
                continue

            # Extend the list of alerts with any new ones so they can be returned
            self._alerts.extend(record_alerts)

            self.alert_forwarder.send_alerts(record_alerts)

        return payload_with_normalized_records
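For context on how this class is typically driven, below is a minimal sketch of a Lambda entry point and of the event shape run() expects. The handler name, module layout, and event field values are assumptions for illustration; only the top-level 'Records' key and the StreamAlert(context).run(event) call pattern follow from the code above.

import base64
import json

# Minimal sketch of a Lambda entry point driving StreamAlert (handler name assumed).
def handler(event, context):
    """Instantiate StreamAlert for this invocation and process the incoming event."""
    StreamAlert(context).run(event)


# Trimmed sketch of an event accepted by run(). Only the top-level 'Records' list
# is required by the code above; the remaining fields follow the standard Kinesis
# event format and the values here are assumptions.
sample_event = {
    'Records': [
        {
            'eventSource': 'aws:kinesis',
            'eventSourceARN': 'arn:aws:kinesis:us-east-1:123456789012:stream/test_stream_threat_intel',
            'kinesis': {
                'data': base64.b64encode(json.dumps({'Field1': {}}).encode())
            }
        }
    ]
}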