def publish_match_message(
    self,
    content_id: str,
    content_hash: str,
    matches: t.List[IndexMatch],
    sns_client: SNSClient,
    topic_arn: str,
):
    """
    Build one BankedSignal per (match, privacy group) pair and publish a
    single SNS message covering the whole list of matches.
    """
    signals: t.List[BankedSignal] = []
    for found in matches:
        meta = found.metadata
        # One signal per privacy group the matched content belongs to.
        for group_id in meta.get("privacy_groups", []):
            signal = BankedSignal(
                str(meta["id"]),
                str(group_id),
                str(meta["source"]),
            )
            # Tags are keyed per privacy group in the metadata dict.
            for classification in meta["tags"].get(group_id, []):
                signal.add_classification(classification)
            signals.append(signal)

    message = MatchMessage(
        content_key=content_id,
        content_hash=content_hash,
        matching_banked_signals=signals,
    )
    sns_client.publish(TopicArn=topic_arn, Message=message.to_aws_json())
def publish_match_message(
    self,
    content_id: str,
    content_hash: str,
    matches: t.List[IndexMatch[t.List[BaseIndexMetadata]]],
    sns_client: SNSClient,
    topic_arn: str,
):
    """
    Convert every match's metadata objects into BankedSignals and publish
    them to SNS as one message for the whole list of matches.
    """
    banked_signals: t.List[BankedSignal] = []

    for match in matches:
        for metadata_obj in match.metadata:
            source = metadata_obj.get_source()

            if source == THREAT_EXCHANGE_SOURCE_SHORT_CODE:
                te_metadata = t.cast(
                    ThreatExchangeIndicatorIndexMetadata, metadata_obj)
                signal = BankedSignal(
                    str(te_metadata.indicator_id),
                    str(te_metadata.privacy_group),
                    str(te_metadata.get_source()),
                )
                for tag in te_metadata.tags:
                    signal.add_classification(tag)
                banked_signals.append(signal)

            elif source == BANKS_SOURCE_SHORT_CODE:
                bank_metadata = t.cast(BankedSignalIndexMetadata, metadata_obj)
                member = self.banks_table.get_bank_member(
                    bank_member_id=bank_metadata.bank_member_id)
                signal = BankedSignal(
                    bank_metadata.bank_member_id,
                    member.bank_id,
                    bank_metadata.get_source(),
                )
                # TODO: This would do good with caching.
                bank = self.banks_table.get_bank(bank_id=member.bank_id)
                # Classifications come from both the member and its bank.
                for tag in set.union(member.bank_member_tags, bank.bank_tags):
                    signal.add_classification(tag)
                banked_signals.append(signal)

    sns_client.publish(
        TopicArn=topic_arn,
        Message=MatchMessage(
            content_key=content_id,
            content_hash=content_hash,
            matching_banked_signals=banked_signals,
        ).to_aws_json(),
    )
def test_webhook_action_repalcement(self):
    """
    Every concrete WebhookActionPerformer subclass should perform its
    action against a URL with <content-id> replaced.
    """
    signal = BankedSignal(
        banked_content_id="4169895076385542",
        bank_id="303636684709969",
        bank_source="te",
    )
    message = MatchMessage("cid1", "0374f1g34f12g34f8", [signal])

    # monitoring page:
    # https://webhook.site/#!/d0dbb19d-2a6f-40be-ad4d-fa9c8b34c8df
    performers = [
        performer_cls(
            name="EnqueueForReview",
            url=
            "https://webhook.site/d0dbb19d-2a6f-40be-ad4d-fa9c8b34c8df/<content-id>",
            headers='{"Connection":"keep-alive"}',
        )
        for performer_cls in WebhookActionPerformer.__subclasses__()
    ]

    for performer in performers:
        performer.perform_action(message)
def perform_action(self, match_message: MatchMessage) -> None:
    """
    Expand every registered URL template token (e.g. "<content-id>") and
    fire the webhook with the serialized match message as the payload.
    """
    url = self.url
    for token, make_replacement in WEBHOOK_ACTION_PERFORMER_REPLACEMENTS.items():
        url = url.replace(token, make_replacement(match_message))
    payload = json.dumps(match_message.to_aws())
    self.call(url, data=payload)
def publish_match_message(
    self,
    content_id: str,
    content_hash: str,
    matches: t.List[IndexMatch],
    sns_client: SNSClient,
    topic_arn: str,
):
    """
    Creates banked signal objects and publishes one message for a list of
    matches to SNS.
    """
    signals = []
    for match in matches:
        for meta in match.metadata:
            source = meta.get_source()
            if source == THREAT_EXCHANGE_SOURCE_SHORT_CODE:
                te_signal = BankedSignal(
                    str(meta.indicator_id),
                    str(meta.privacy_group),
                    str(meta.get_source()),
                )
                for tag in meta.tags:
                    te_signal.add_classification(tag)
                signals.append(te_signal)
            elif source == BANKS_SOURCE_SHORT_CODE:
                # Bank matches carry no per-tag classifications here.
                signals.append(
                    BankedSignal(
                        meta.signal_id,
                        meta.bank_member_id,
                        meta.get_source(),
                    ))

    sns_client.publish(
        TopicArn=topic_arn,
        Message=MatchMessage(
            content_key=content_id,
            content_hash=content_hash,
            matching_banked_signals=signals,
        ).to_aws_json(),
    )
def lambda_handler(event, context):
    """
    This lambda is called when one or more matches are found. If a single
    hash matches multiple datasets, this will be called only once.

    Action labels are generated for each match message, then an action is
    performed corresponding to each action label.
    """
    config = ActionEvaluatorConfig.get()

    for record in event["Records"]:
        # TODO research max # sqs records / lambda_handler invocation
        body = json.loads(record["body"])
        logger.info("sqs record body %s", record["body"])

        match_message = MatchMessage.from_aws_json(body["Message"])
        logger.info("Evaluating match_message: %s", match_message)

        rules = get_action_rules()
        logger.info("Evaluating against action_rules: %s", rules)

        content = ContentObject.get_from_content_id(
            config.dynamo_db_table, match_message.content_key)

        label_to_rules = get_actions_to_take(
            match_message,
            rules,
            content.additional_fields,
        )

        # Enqueue one action message per action label that fired.
        for label, matched_rules in label_to_rules.items():
            action_message = ActionMessage.from_match_message_action_label_action_rules_and_additional_fields(
                match_message,
                label,
                matched_rules,
                list(content.additional_fields),
            )
            logger.info("Sending Action message: %s", action_message)
            config.sqs_client.send_message(
                QueueUrl=config.actions_queue_url,
                MessageBody=action_message.to_aws_json(),
            )

        # Always record a "saw this too" writeback for the match.
        WritebackMessage.from_match_message_and_type(
            match_message, WritebackTypes.SawThisToo).send_to_queue(
                config.sqs_client, config.writeback_queue_url)

    return {"evaluation_completed": "true"}
def lambda_handler(event, context):
    """
    Listens to events on a queue attached to the match SNS topic and
    increments various counters. Presently supported:

    ---
    1. split of hash and matches by privacy group

    Each SQS record body wraps an SNS-delivered MatchMessage JSON payload
    under the "Message" key. Counters are accumulated in memory for the
    whole batch, then flushed to dynamodb once at the end.
    """
    records_table = dynamodb.Table(DYNAMODB_TABLE)

    # defaultdict(int) is the idiomatic zero-initialized counter
    # (equivalent to the previous `defaultdict(lambda: 0)`).
    counters = defaultdict(int)

    for sqs_record in event["Records"]:
        sqs_record_body = json.loads(sqs_record["body"])
        match_message = MatchMessage.from_aws_json(sqs_record_body["Message"])

        # NOTE(review): bank_id is used as the privacy group id here —
        # true for ThreatExchange-sourced signals; confirm for bank matches.
        for signal in match_message.matching_banked_signals:
            privacy_group_id = signal.bank_id
            counters[privacy_group_id] += 1

    logger.debug("Flushing %s", counters)

    # Flush counters to dynamodb
    MatchByPrivacyGroupCounter.increment_counts(records_table, counters)
if __name__ == "__main__": # For basic debugging HMAConfig.initialize(os.environ["CONFIG_TABLE_NAME"]) action_rules = get_action_rules() match_message = MatchMessage( content_key="m2", content_hash= "361da9e6cf1b72f5cea0344e5bb6e70939f4c70328ace762529cac704297354a", matching_banked_signals=[ BankedSignal( banked_content_id="3070359009741438", bank_id="258601789084078", bank_source="te", classifications={ BankedContentIDClassificationLabel( value="258601789084078"), ClassificationLabel(value="true_positive"), BankSourceClassificationLabel(value="te"), BankIDClassificationLabel(value="3534976909868947"), }, ) ], ) event = { "Records": [{ "body": json.dumps({"Message": match_message.to_aws_json()}) }]
def test_get_action_labels(self):
    """
    get_actions_to_take should fire a rule only when a match's
    classifications satisfy both the must-have and must-not-have label
    sets of that rule.
    """
    enqueue_for_review_action_label = ActionLabel("EnqueueForReview")
    bank_id = "12345"

    # Signal carrying Bar/Xyz but NOT the disqualifying "Foo" label.
    banked_signal_without_foo = BankedSignal("67890", bank_id, "Test")
    banked_signal_without_foo.add_classification("Bar")
    banked_signal_without_foo.add_classification("Xyz")

    # Same signal, plus "Foo" which the rule below forbids.
    banked_signal_with_foo = BankedSignal("67890", bank_id, "Test")
    banked_signal_with_foo.add_classification("Foo")
    banked_signal_with_foo.add_classification("Bar")
    banked_signal_with_foo.add_classification("Xyz")

    match_message_without_foo = MatchMessage("key", "hash",
                                             [banked_signal_without_foo])
    match_message_with_foo = MatchMessage("key", "hash",
                                          [banked_signal_with_foo])

    # Rule: must match the bank id, must NOT carry the "Foo" label.
    action_rules = [
        ActionRule(
            enqueue_for_review_action_label.value,
            enqueue_for_review_action_label,
            set([BankIDClassificationLabel(bank_id)]),
            set([ClassificationLabel("Foo")]),
        )
    ]

    action_label_to_action_rules: t.Dict[
        ActionLabel, t.List[ActionRule]] = get_actions_to_take(
            match_message_without_foo, action_rules, set())

    assert len(action_label_to_action_rules) == 1
    self.assertIn(
        enqueue_for_review_action_label,
        action_label_to_action_rules,
        "enqueue_for_review_action_label should be in action_label_to_action_rules",
    )

    # The "Foo" classification must suppress the action entirely.
    action_label_to_action_rules = get_actions_to_take(
        match_message_with_foo, action_rules, set())
    assert len(action_label_to_action_rules) == 0

    enqueue_mini_castle_for_review_action_label = ActionLabel(
        "EnqueueMiniCastleForReview")
    enqueue_sailboat_for_review_action_label = ActionLabel(
        "EnqueueSailboatForReview")

    # Two rules over the same bank: the first excludes the sailboat
    # content id, the second requires it — so each message should fire
    # exactly one of them.
    action_rules = [
        ActionRule(
            name="Enqueue Mini-Castle for Review",
            action_label=enqueue_mini_castle_for_review_action_label,
            must_have_labels=set([
                BankIDClassificationLabel("303636684709969"),
                ClassificationLabel("true_positive"),
            ]),
            must_not_have_labels=set(
                [BankedContentIDClassificationLabel("3364504410306721")]),
        ),
        ActionRule(
            name="Enqueue Sailboat for Review",
            action_label=enqueue_sailboat_for_review_action_label,
            must_have_labels=set([
                BankIDClassificationLabel("303636684709969"),
                ClassificationLabel("true_positive"),
                BankedContentIDClassificationLabel("3364504410306721"),
            ]),
            must_not_have_labels=set(),
        ),
    ]

    mini_castle_banked_signal = BankedSignal(
        banked_content_id="4169895076385542",
        bank_id="303636684709969",
        bank_source="te",
    )
    mini_castle_banked_signal.add_classification("true_positive")

    mini_castle_match_message = MatchMessage(
        content_key="images/mini-castle.jpg",
        content_hash=
        "361da9e6cf1b72f5cea0344e5bb6e70939f4c70328ace762529cac704297354a",
        matching_banked_signals=[mini_castle_banked_signal],
    )

    sailboat_banked_signal = BankedSignal(
        banked_content_id="3364504410306721",
        bank_id="303636684709969",
        bank_source="te",
    )
    sailboat_banked_signal.add_classification("true_positive")

    sailboat_match_message = MatchMessage(
        content_key="images/sailboat-mast-and-sun.jpg",
        content_hash=
        "388ff5e1084efef10096df9cb969296dff2b04d67a94065ecd292129ef6b1090",
        matching_banked_signals=[sailboat_banked_signal],
    )

    # Mini-castle content id is not the excluded one -> first rule fires.
    action_label_to_action_rules = get_actions_to_take(
        mini_castle_match_message, action_rules, set())
    assert len(action_label_to_action_rules) == 1
    self.assertIn(
        enqueue_mini_castle_for_review_action_label,
        action_label_to_action_rules,
        "enqueue_mini_castle_for_review_action_label should be in action_label_to_action_rules",
    )

    # Sailboat content id satisfies the second rule (and is excluded
    # from the first) -> only the sailboat rule fires.
    action_label_to_action_rules = get_actions_to_take(
        sailboat_match_message, action_rules, set())
    assert len(action_label_to_action_rules) == 1
    self.assertIn(
        enqueue_sailboat_for_review_action_label,
        action_label_to_action_rules,
        "enqueue_sailboat_for_review_action_label should be in action_label_to_action_rules",
    )
class WritebackerTestCase(unittest.TestCase):
    """Exercises the writebacker lambda against the mocked TE API."""

    # Three ThreatExchange-sourced signals (two in "pg 4", one in "pg 3")
    # plus one from a non-TE source, which the writebacker ignores.
    banked_signals = [
        BankedSignal("2862392437204724", "pg 4", "te"),
        BankedSignal("4194946153908639", "pg 4", "te"),
        BankedSignal("3027465034605137", "pg 3", "te"),
        BankedSignal("evil.jpg", "bank 4", "non-te-source"),
    ]
    match_message = MatchMessage("key", "hash", banked_signals)

    # Writebacks are enabled for the trustworthy privacy group ("pg 4"),
    # not for the untrustworthy one ("pg 3").
    configs = [
        ThreatExchangeConfig("pg 4", True, "Trustworthy PG",
                             "test description", True, True, True),
        ThreatExchangeConfig("pg 3", True, "UnTrustworthy PG",
                             "test description", True, False, True),
    ]
    # Runs at class-definition time so every test sees the mocked configs.
    for config in configs:
        hmaconfig.mock_create_config(config)

    def test_saw_this_too(self):
        """SawThisToo reacts on enabled PGs and skips disabled ones."""
        os.environ["MOCK_TE_API"] = "True"
        os.environ["CONFIG_TABLE_NAME"] = "test-HMAConfig"

        writeback = WritebackTypes.SawThisToo
        writeback_message = WritebackMessage.from_match_message_and_type(
            self.match_message, writeback)
        event = {"Records": [{"body": writeback_message.to_aws_json()}]}

        result = lambda_handler(event, None)
        # NOTE(review): the "becuase" typo mirrors the message the mocked
        # writebacker emits; fixing it only here would break the assertion.
        assert result == {
            "writebacks_performed": {
                "te": [
                    "Reacted SAW_THIS_TOO to descriptor a2|2862392437204724\nReacted SAW_THIS_TOO to descriptor a3|2862392437204724",
                    "Reacted SAW_THIS_TOO to descriptor a2|4194946153908639\nReacted SAW_THIS_TOO to descriptor a3|4194946153908639",
                    "No writeback performed for banked content id 3027465034605137 becuase writebacks were disabled",
                ]
            }
        }
        os.environ["MOCK_TE_API"] = "False"

    def test_false_positive(self):
        """FalsePositive issues DISAGREE_WITH_TAGS reactions."""
        os.environ["MOCK_TE_API"] = "True"
        os.environ["CONFIG_TABLE_NAME"] = "test-HMAConfig"

        writeback = WritebackTypes.FalsePositive
        writeback_message = WritebackMessage.from_match_message_and_type(
            self.match_message, writeback)
        event = {"Records": [{"body": writeback_message.to_aws_json()}]}

        result = lambda_handler(event, None)
        assert result == {
            "writebacks_performed": {
                "te": [
                    "Reacted DISAGREE_WITH_TAGS to descriptor a2|2862392437204724\nReacted DISAGREE_WITH_TAGS to descriptor a3|2862392437204724",
                    "Reacted DISAGREE_WITH_TAGS to descriptor a2|4194946153908639\nReacted DISAGREE_WITH_TAGS to descriptor a3|4194946153908639",
                    "No writeback performed for banked content id 3027465034605137 becuase writebacks were disabled",
                ]
            }
        }
        os.environ["MOCK_TE_API"] = "False"

    def test_true_positve(self):
        """TruePositive writes back and builds a new descriptor."""
        os.environ["MOCK_TE_API"] = "True"
        os.environ["CONFIG_TABLE_NAME"] = "test-HMAConfig"

        writeback = WritebackTypes.TruePositive
        writeback_message = WritebackMessage.from_match_message_and_type(
            self.match_message, writeback)
        event = {"Records": [{"body": writeback_message.to_aws_json()}]}

        result = lambda_handler(event, None)
        assert result == {
            "writebacks_performed": {
                "te": [
                    "Wrote back TruePositive for indicator 2862392437204724\nBuilt descriptor a1|2862392437204724 with privacy groups pg 4",
                    "Wrote back TruePositive for indicator 4194946153908639\nBuilt descriptor a1|4194946153908639 with privacy groups pg 4",
                    "No writeback performed for banked content id 3027465034605137 becuase writebacks were disabled",
                ]
            }
        }
        os.environ["MOCK_TE_API"] = "False"

    def test_remove_opinion(self):
        """RemoveOpinion deletes our descriptor and removes reactions."""
        os.environ["MOCK_TE_API"] = "True"
        os.environ["CONFIG_TABLE_NAME"] = "test-HMAConfig"

        writeback = WritebackTypes.RemoveOpinion
        writeback_message = WritebackMessage.from_match_message_and_type(
            self.match_message, writeback)
        event = {"Records": [{"body": writeback_message.to_aws_json()}]}

        result = lambda_handler(event, None)
        assert result == {
            "writebacks_performed": {
                "te": [
                    "\n".join((
                        "Deleted decriptor a1|2862392437204724 for indicator 2862392437204724",
                        "Removed reaction DISAGREE_WITH_TAGS from descriptor a2|2862392437204724",
                        "Removed reaction DISAGREE_WITH_TAGS from descriptor a3|2862392437204724",
                    )),
                    "\n".join((
                        "Deleted decriptor a1|4194946153908639 for indicator 4194946153908639",
                        "Removed reaction DISAGREE_WITH_TAGS from descriptor a2|4194946153908639",
                        "Removed reaction DISAGREE_WITH_TAGS from descriptor a3|4194946153908639",
                    )),
                    "No writeback performed for banked content id 3027465034605137 becuase writebacks were disabled",
                ]
            }
        }
        os.environ["MOCK_TE_API"] = "False"
def lambda_handler(event, context):
    """
    TODO/FIXME migrate this lambda to be a part of matcher.py
    Listens to SQS events fired when new hash is generated. Loads the index
    stored in an S3 bucket and looks for a match.

    As per the default configuration
    - the index data bucket is INDEXES_BUCKET_NAME
    - the key name must be S3BackedPDQIndex._get_index_s3_key()

    When matched, publishes a notification to an SNS endpoint. Note this is in
    contrast with hasher and indexer. They publish to SQS directly. Publishing
    to SQS implies there can be only one consumer. Because, here, in the
    matcher, we publish to SNS, we can plug multiple queues behind it and
    profit!
    """
    records_table = dynamodb.Table(DYNAMODB_TABLE)

    # Index is fetched from S3 once per lambda invocation, not per record.
    hash_index: PDQIndex = get_index(INDEXES_BUCKET_NAME)
    logger.info("loaded_hash_index")

    for sqs_record in event["Records"]:
        message = json.loads(sqs_record["body"])

        # SNS/SQS wiring emits test events that carry no hash; skip them.
        if message.get("Event") == "TestEvent":
            logger.info("Disregarding Test Event")
            continue

        hash_str = message["hash"]
        key = message["key"]
        current_datetime = datetime.datetime.now()

        with metrics.timer(metrics.names.pdq_matcher_lambda.search_index):
            results = hash_index.query(hash_str)

        if results:
            match_ids = []
            matching_banked_signals: t.List[BankedSignal] = []
            for match in results:
                metadata = match.metadata
                logger.info(
                    "Match found for key: %s, hash %s -> %s", key, hash_str,
                    metadata)
                # Drop privacy groups whose matcher is currently inactive.
                privacy_group_list = metadata.get("privacy_groups", [])
                metadata["privacy_groups"] = list(
                    filter(
                        lambda x: get_privacy_group_matcher_active(
                            str(x),
                            time.time() // CACHED_TIME,
                            # CACHED_TIME defaults to 300 seconds; this turns
                            # time.time() into an int bucket that only changes
                            # every 300 seconds, so the lookup can be cached.
                        ),
                        privacy_group_list,
                    ))

                # Only record/publish if at least one active group remains.
                if metadata["privacy_groups"]:
                    signal_id = str(metadata["id"])

                    with metrics.timer(
                            metrics.names.pdq_matcher_lambda.write_match_record
                    ):
                        # TODO: Add source (threatexchange) tags to match record
                        MatchRecord(
                            key,
                            PdqSignal,
                            hash_str,
                            current_datetime,
                            signal_id,
                            metadata["source"],
                            metadata["hash"],
                        ).write_to_table(records_table)

                    for pg in metadata.get("privacy_groups", []):
                        # Only write the metadata if it is not found in the
                        # table; once initially created it is the fetcher's
                        # job to keep the item up to date.
                        PDQSignalMetadata(
                            signal_id,
                            pg,
                            current_datetime,
                            metadata["source"],
                            metadata["hash"],
                            metadata["tags"].get(pg, []),
                        ).write_to_table_if_not_found(records_table)

                    match_ids.append(signal_id)

                    # TODO: change naming upstream and here from
                    # privacy_group[s] to dataset[s]
                    for privacy_group in metadata.get("privacy_groups", []):
                        banked_signal = BankedSignal(
                            str(signal_id), str(privacy_group),
                            str(metadata["source"]))
                        for tag in metadata["tags"].get(privacy_group, []):
                            banked_signal.add_classification(tag)
                        matching_banked_signals.append(banked_signal)

            # TODO: Add source (threatexchange) tags to match message
            if matching_banked_signals:
                match_message = MatchMessage(
                    content_key=key,
                    content_hash=hash_str,
                    matching_banked_signals=matching_banked_signals,
                )

                logger.info(f"Publishing match_message: {match_message}")

                # Publish one message for the set of matches.
                sns_client.publish(
                    TopicArn=OUTPUT_TOPIC_ARN,
                    Message=match_message.to_aws_json())
        else:
            logger.info(f"No matches found for key: {key} hash: {hash_str}")

    metrics.flush()