def test_order_of_signals_is_chronological(self):
    """The to-process page must return signals in updated_at order."""
    with self.fresh_dynamodb():
        table_manager = BanksTable(self.get_table(),
                                   get_default_signal_type_mapping())
        bank_id, bank_member_id = self._create_bank_and_bank_member()

        created = []
        for _ in range(20):
            created.append(
                table_manager.add_bank_member_signal(
                    bank_id=bank_id,
                    bank_member_id=bank_member_id,
                    signal_type=VideoMD5Signal,
                    signal_value="A VIDEO MD5 SIGNAL. WILTY?"
                    + str(random.random()),
                ))

        # Expected order is chronological by updated_at.
        expected_ids = [
            s.signal_id for s in sorted(created, key=lambda s: s.updated_at)
        ]

        page = table_manager.get_bank_member_signals_to_process_page(
            signal_type=VideoMD5Signal)
        actual_ids = [signal.signal_id for signal in page.items]

        self.assertListEqual(expected_ids, actual_ids)
def remove_bank_member(
    banks_table: BanksTable,
    bank_member_id: str,
):
    """
    Remove bank member. Marks the member as removed and all its signals are
    removed from the GSI used to build HMA indexes.

    NOTE: If we ever start incremental updates to HMA indexes, removing bank
    members will stop working.

    Args:
        banks_table: table wrapper used for both removal operations.
        bank_member_id: the member to remove.
    """
    # Drop the member's signals from the to-process GSI first, so the next
    # index rebuild no longer picks them up; then soft-delete the member.
    banks_table.remove_bank_member_signals_to_process(
        bank_member_id=bank_member_id)
    banks_table.remove_bank_member(bank_member_id=bank_member_id)
def _create_bank_and_bank_member(self) -> t.Tuple[str, str]:
    """Create a test bank plus one video member; return (bank_id, member_id)."""
    manager = BanksTable(self.get_table())

    bank = manager.create_bank("TEST_BANK", "Test bank description")
    member = manager.add_bank_member(
        bank_id=bank.bank_id,
        content_type=VideoContent,
        raw_content=None,
        storage_bucket="hma-test-media",
        storage_key="irrrelevant",
        notes="",
    )

    return bank.bank_id, member.bank_member_id
def lambda_handler(event, context): """ Runs on a schedule. On each run, gets all data files for ALL_INDEXABLE_SIGNAL_TYPES from s3, converts the raw data file into an index and writes to an output S3 bucket. As per the default configuration, the bucket must be - the hashing data bucket eg. dipanjanm-hashing-<...> - the key name must be in the ThreatExchange folder (eg. threat_exchange_data/) - the key name must return a signal_type in ThreatUpdateS3Store.get_signal_type_from_object_key """ # Note: even though we know which files were updated, threatexchange indexes # do not yet allow adding new entries. So, we must do a full rebuild. So, we # only end up using the signal types that were updated, not the actual files # that changed. s3_config = S3ThreatDataConfig( threat_exchange_data_bucket_name=THREAT_EXCHANGE_DATA_BUCKET_NAME, threat_exchange_data_folder=THREAT_EXCHANGE_DATA_FOLDER, ) banks_table = BanksTable(dynamodb.Table(BANKS_TABLE)) for signal_type in ALL_INDEXABLE_SIGNAL_TYPES: adapter_class = _ADAPTER_MAPPING[signal_type] data_files = adapter_class( config=s3_config, metrics_logger=metrics.names.indexer).load_data() bank_data = get_all_bank_hash_rows(signal_type, banks_table) with metrics.timer(metrics.names.indexer.merge_datafiles): logger.info(f"Merging {signal_type} Hash files") # go from dict[filename, list<hash rows>] → list<hash rows> flattened_data = [ hash_row for file_ in data_files.values() for hash_row in file_ ] merged_data = functools.reduce(merge_hash_rows_on_hash_value, flattened_data + bank_data, {}).values() with metrics.timer(metrics.names.indexer.build_index): logger.info(f"Rebuilding {signal_type} Index") for index_class in INDEX_MAPPING[signal_type]: index: S3BackedInstrumentedIndexMixin = index_class.build( merged_data) logger.info( f"Putting {signal_type} index in S3 for index {index.get_index_class_name()}" ) index.save(bucket_name=INDEXES_BUCKET_NAME) metrics.flush() logger.info("Index updates complete")
def _create_200_members(self) -> str:
    """Create a bank, 200 members and return bank_id."""
    manager = BanksTable(self.get_table(), get_default_signal_type_mapping())
    bank = manager.create_bank("TEST_BANK", "TEST BANK Description")

    remaining = 200
    while remaining:
        manager.add_bank_member(
            bank_id=bank.bank_id,
            content_type=PhotoContent,
            raw_content=None,
            storage_bucket="hma-test-media",
            storage_key="videos/breaking-news.mp4",
            notes="",
        )
        remaining -= 1

    return bank.bank_id
def test_single_signal_is_retrieved(self):
    """A freshly added signal must appear on the to-process page."""
    with self.fresh_dynamodb():
        manager = BanksTable(self.get_table())
        bank_id, bank_member_id = self._create_bank_and_bank_member()

        added = manager.add_bank_member_signal(
            bank_id=bank_id,
            bank_member_id=bank_member_id,
            signal_type=VideoTmkPdqfSignal,
            signal_value="A VIDEO TMK PDQF SIGNAL. WILTY?",
        )

        # expect this to now be available to process
        page = manager.get_bank_member_signals_to_process_page(
            signal_type=VideoTmkPdqfSignal)

        self.assertEqual(len(page.items), 1)
        self.assertEqual(added.signal_id, page.items[0].signal_id)
def test_multiple_signals_are_retrieved(self):
    """All added signals must appear, in insertion order, on the page."""
    with self.fresh_dynamodb():
        manager = BanksTable(self.get_table())
        bank_id, bank_member_id = self._create_bank_and_bank_member()

        added_ids = []
        for _ in range(20):
            signal = manager.add_bank_member_signal(
                bank_id=bank_id,
                bank_member_id=bank_member_id,
                signal_type=VideoTmkPdqfSignal,
                signal_value="A VIDEO TMK PDQF SIGNAL. WILTY?"
                + str(random.random()),
            )
            added_ids.append(signal.signal_id)

        page = manager.get_bank_member_signals_to_process_page(
            signal_type=VideoTmkPdqfSignal)
        retrieved_ids = [signal.signal_id for signal in page.items]

        self.assertListEqual(added_ids, retrieved_ids)
def test_order_of_signals_multi_page(self):
    """Chronological ordering must hold across pagination boundaries.

    Adds 20 signals, fetches them 4 at a time, and checks the concatenated
    pages come back sorted by updated_at.
    """
    with self.fresh_dynamodb():
        table_manager = BanksTable(self.get_table(),
                                   get_default_signal_type_mapping())
        bank_id, bank_member_id = self._create_bank_and_bank_member()

        signals = [
            table_manager.add_bank_member_signal(
                bank_id=bank_id,
                bank_member_id=bank_member_id,
                signal_type=VideoMD5Signal,
                # FIX: fixture text previously said "TMK PDQF" although the
                # signal type is VideoMD5Signal; keep fixture honest and
                # consistent with the single-page chronological test.
                signal_value="A VIDEO MD5 SIGNAL. WILTY?"
                + str(random.random()),
            ) for _ in range(20)
        ]

        # Expected order: chronological by updated_at.
        signal_ids_in_order = [
            s.signal_id for s in sorted(signals, key=lambda x: x.updated_at)
        ]

        queried_signal_ids = []
        exclusive_start_key = None
        while True:
            response = table_manager.get_bank_member_signals_to_process_page(
                signal_type=VideoMD5Signal,
                limit=4,
                exclusive_start_key=exclusive_start_key,
            )
            exclusive_start_key = response.last_evaluated_key
            queried_signal_ids += [
                signal.signal_id for signal in response.items
            ]

            if not response.has_next_page():
                break

        self.assertListEqual(signal_ids_in_order, queried_signal_ids)
def _create_banks(self):
    """Create one active and one inactive bank, each with one photo member.

    Stores everything on self (table_manager, active_bank,
    active_bank_member, inactive_bank, inactive_bank_member) for the tests
    to use. Banks are created first and then flipped active/inactive via
    update_bank, since create_bank here takes only name and description.
    """
    self.table_manager = BanksTable(
        self.get_table(), get_default_signal_type_mapping()
    )

    self.active_bank = self.table_manager.create_bank("TEST_BANK", "Is Active")
    self.active_bank_member = self.table_manager.add_bank_member(
        bank_id=self.active_bank.bank_id,
        content_type=PhotoContent,
        raw_content=None,
        storage_bucket=None,
        storage_key=None,
        notes=None,
    )
    # Mark the first bank active (only is_active differs from creation).
    self.table_manager.update_bank(
        bank_id=self.active_bank.bank_id,
        bank_name=self.active_bank.bank_name,
        bank_description=self.active_bank.bank_description,
        is_active=True,
    )

    self.inactive_bank = self.table_manager.create_bank(
        "TEST_BANK_2", "Is Inactive"
    )
    # Mark the second bank inactive.
    self.table_manager.update_bank(
        bank_id=self.inactive_bank.bank_id,
        bank_name=self.inactive_bank.bank_name,
        bank_description=self.inactive_bank.bank_description,
        is_active=False,
    )
    self.inactive_bank_member = self.table_manager.add_bank_member(
        bank_id=self.inactive_bank.bank_id,
        content_type=PhotoContent,
        raw_content=None,
        storage_bucket=None,
        storage_key=None,
        notes=None,
    )
def get_banked_signal_details(
    banks_table: BanksTable,
    signal_id: str,
    signal_source: str,
) -> t.List[BankedSignalDetailsMetadata]:
    """Return bank/member metadata for a bank-sourced signal id.

    Yields an empty list for blank inputs or non-bank sources.
    """
    is_bank_signal = (signal_id and signal_source
                      and signal_source == BANKS_SOURCE_SHORT_CODE)
    if not is_bank_signal:
        return []

    details = []
    for member_signal in banks_table.get_bank_member_signal_from_id(signal_id):
        details.append(
            BankedSignalDetailsMetadata(
                bank_member_id=member_signal.bank_member_id,
                bank_id=member_signal.bank_id,
            ))
    return details
def test_bank_member_removes(self):
    """Removing a member's signals empties the to-process page."""
    with self.fresh_dynamodb():
        table_manager = BanksTable(self.get_table())
        bank_id, bank_member_id = self._create_bank_and_bank_member()

        # Add three distinct signals for the same member.
        for value in (
                "A VIDEO TMK PDQF SIGNAL. WILTY?",
                "ANOTHER VIDEO TMK PDQF SIGNAL. WILTY?",
                "An ANOTHER VIDEO TMK PDQF SIGNAL. WILTY?",
        ):
            table_manager.add_bank_member_signal(
                bank_id=bank_id,
                bank_member_id=bank_member_id,
                signal_type=VideoTmkPdqfSignal,
                signal_value=value,
            )

        # expect this to now be available to process
        page = table_manager.get_bank_member_signals_to_process_page(
            signal_type=VideoTmkPdqfSignal)
        self.assertEqual(len(page.items), 3)

        table_manager.remove_bank_member_signals_to_process(
            bank_member_id=bank_member_id)

        # After removal, nothing should remain to process.
        page = table_manager.get_bank_member_signals_to_process_page(
            signal_type=VideoTmkPdqfSignal)
        self.assertEqual(len(page.items), 0)
def get_bank_api(bank_table: Table) -> bottle.Bottle:
    """
    Closure for dependencies of the bank API

    Routes close over `table_manager`; the caller mounts the returned app
    under its own prefix.
    """
    bank_api = bottle.Bottle()
    table_manager = BanksTable(table=bank_table)

    @bank_api.get("/get-all-banks", apply=[jsoninator])
    def get_all_banks() -> AllBanksEnvelope:
        """
        Get all banks.
        """
        return AllBanksEnvelope(banks=table_manager.get_all_banks())

    @bank_api.get("/get-bank/<bank_id>", apply=[jsoninator])
    def get_bank(bank_id=None) -> Bank:
        """
        Get a specific bank from a bank_id.
        """
        bank = table_manager.get_bank(bank_id=bank_id)
        return bank

    @bank_api.post("/create-bank", apply=[jsoninator])
    def create_bank() -> Bank:
        """
        Create a bank using only the name and description.
        """
        return table_manager.create_bank(
            bank_name=bottle.request.json["bank_name"],
            bank_description=bottle.request.json["bank_description"],
        )

    @bank_api.post("/update-bank/<bank_id>", apply=[jsoninator])
    def update_bank(bank_id=None) -> Bank:
        """
        Update name and description for a bank_id.
        """
        return table_manager.update_bank(
            bank_id=bank_id,
            bank_name=bottle.request.json["bank_name"],
            bank_description=bottle.request.json["bank_description"],
        )

    return bank_api
def add_detached_bank_member_signal(
    banks_table: BanksTable,
    bank_id: str,
    content_type: t.Type[ContentType],
    signal_type: t.Type[SignalType],
    signal_value: str,
) -> BankMemberSignal:
    """
    Add a bank member signal without a BankMember. Will deduplicate a
    signal_value + signal_type tuple before writing to the database.

    Will make signals available for processing into indices.

    Args:
        banks_table: table wrapper the signal is written through.
        bank_id: bank the detached signal belongs to.
        content_type: content type for the virtual member.
        signal_type: type of the hash/signal.
        signal_value: the hash itself.

    Returns:
        The stored (new or deduplicated existing) BankMemberSignal.
    """
    # Pure delegation; dedupe logic lives inside BanksTable.
    return banks_table.add_detached_bank_member_signal(
        bank_id=bank_id,
        content_type=content_type,
        signal_type=signal_type,
        signal_value=signal_value,
    )
def add_bank_member_signal(
    banks_table: BanksTable,
    bank_id: str,
    bank_member_id: str,
    signal_type: t.Type[SignalType],
    signal_value: str,
) -> BankMemberSignal:
    """
    Add a bank member signal. Will deduplicate a signal_value + signal_type
    tuple before writing to the database.

    Calling this API also makes the signal (new or existing) available to
    process into matching indices.

    Args:
        banks_table: table wrapper the signal is written through.
        bank_id: bank that owns the member.
        bank_member_id: member the signal was extracted from.
        signal_type: type of the hash/signal.
        signal_value: the hash itself.

    Returns:
        The stored (new or deduplicated existing) BankMemberSignal.
    """
    # Pure delegation; dedupe logic lives inside BanksTable.
    return banks_table.add_bank_member_signal(
        bank_id=bank_id,
        bank_member_id=bank_member_id,
        signal_type=signal_type,
        signal_value=signal_value,
    )
def add_detached_bank_member_signal_batch(
    banks_table: BanksTable,
    bank_id: str,
    signals: t.Iterable[Signal],
) -> t.Iterable[BankMemberSignal]:
    """
    Dump multiple detached signals into a bank. Check
    add_detached_bank_member_signal for more details.

    TODO: At this point, is dumb. Does not actually batch the requests,
    instead loops through signals and calls single APIs.
    """
    # One write per signal; returns the stored records in input order.
    return [
        banks_table.add_detached_bank_member_signal(
            bank_id=bank_id,
            content_type=signal.content_type,
            signal_type=signal.signal_type,
            signal_value=signal.signal_value,
        )
        for signal in signals
    ]
def add_bank_member(
    banks_table: BanksTable,
    sqs_client: SQSClient,
    submissions_queue_url: str,
    bank_id: str,
    content_type: t.Type[ContentType],
    storage_bucket: t.Optional[str],
    storage_key: t.Optional[str],
    raw_content: t.Optional[str],
    notes: str,
    bank_member_tags: t.Set[str],
) -> BankMember:
    """
    Write bank-member to database. Send a message to hashing lambda to extract
    signals.

    Returns the stored BankMember.
    """
    member = banks_table.add_bank_member(
        bank_id=bank_id,
        content_type=content_type,
        storage_bucket=storage_bucket,
        storage_key=storage_key,
        raw_content=raw_content,
        notes=notes,
        bank_member_tags=bank_member_tags,
    )

    # NOTE(review): storage_bucket/storage_key are Optional, but a presigned
    # URL is built unconditionally — presumably create_presigned_url tolerates
    # None for raw_content submissions; confirm.
    submission_message = BankSubmissionMessage(
        content_type=content_type,
        # 3600s = 1 hour validity for the hashing lambda to fetch the media.
        url=create_presigned_url(storage_bucket, storage_key, None, 3600,
                                 "get_object"),
        bank_id=bank_id,
        bank_member_id=member.bank_member_id,
    )
    sqs_client.send_message(
        QueueUrl=submissions_queue_url,
        MessageBody=json.dumps(submission_message.to_sqs_message()),
    )

    return member
def get_all_bank_hash_rows(signal_type: t.Type[SignalType],
                           banks_table: BanksTable) -> t.Iterable[HashRowT]:
    """
    Make repeated calls to banks table to get all hashes for a signal type.

    Returns list[HashRowT]. HashRowT is a tuple of hash_value and some
    metadata about the signal.
    """
    hash_rows: t.List[HashRowT] = []
    last_key = None
    has_more = True

    # Exhaust the to-process pages, converting each signal to a HashRowT.
    while has_more:
        page = banks_table.get_bank_member_signals_to_process_page(
            signal_type=signal_type, exclusive_start_key=last_key)

        hash_rows.extend((
            signal.signal_value,
            [
                BankedSignalIndexMetadata(
                    signal.signal_id,
                    signal.signal_value,
                    signal.bank_member_id,
                ),
            ],
        ) for signal in page.items)

        last_key = page.last_evaluated_key
        has_more = page.has_next_page()

    logger.info(
        f"Obtained {len(hash_rows)} hash records from banks for signal_type:{signal_type.get_name()}"
    )
    return hash_rows
class MatchFiltersTestCase(BanksTableTestBase, unittest.TestCase):
    """Tests for Matcher.filter_match_results: privacy-group activeness,
    distance thresholds, and bank activeness filtering."""

    # NOTE: Table is defined in base class BanksTableTestBase

    def _create_banks(self):
        """Create one active and one inactive bank, each with a photo member."""
        self.table_manager = BanksTable(
            self.get_table(), get_default_signal_type_mapping()
        )
        self.active_bank = self.table_manager.create_bank("TEST_BANK", "Is Active")
        self.active_bank_member = self.table_manager.add_bank_member(
            bank_id=self.active_bank.bank_id,
            content_type=PhotoContent,
            raw_content=None,
            storage_bucket=None,
            storage_key=None,
            notes=None,
        )
        # Banks are created then flipped active/inactive via update_bank.
        self.table_manager.update_bank(
            bank_id=self.active_bank.bank_id,
            bank_name=self.active_bank.bank_name,
            bank_description=self.active_bank.bank_description,
            is_active=True,
        )
        self.inactive_bank = self.table_manager.create_bank(
            "TEST_BANK_2", "Is Inactive"
        )
        self.table_manager.update_bank(
            bank_id=self.inactive_bank.bank_id,
            bank_name=self.inactive_bank.bank_name,
            bank_description=self.inactive_bank.bank_description,
            is_active=False,
        )
        self.inactive_bank_member = self.table_manager.add_bank_member(
            bank_id=self.inactive_bank.bank_id,
            content_type=PhotoContent,
            raw_content=None,
            storage_bucket=None,
            storage_key=None,
            notes=None,
        )

    def _create_privacy_groups(self):
        """Create one active and one inactive ThreatExchange privacy group."""
        # Since we already have a mock_dynamodb2 courtesy BanksTableTestBase,
        # re-use it for initing configs. Requires some clever hot-wiring.
        config_test_mock = config_test.ConfigTest()
        config_test_mock.mock_dynamodb2 = self.__class__.mock_dynamodb2
        config_test_mock.create_mocked_table()
        HMAConfig.initialize(config_test_mock.TABLE_NAME)
        # Hot wiring ends...

        self.active_pg = ThreatExchangeConfig(
            "ACTIVE_PG", True, "", True, True, True, "ACTIVE_PG"
        )
        create_config(self.active_pg)

        # Active PG has a distance threshold of 31.
        create_config(AdditionalMatchSettingsConfig("ACTIVE_PG", 31))

        self.inactive_pg = ThreatExchangeConfig(
            "INACTIVE_PG", True, "", True, True, False, "INACTIVE_PG"
        )
        create_config(self.inactive_pg)

    def _init_data_if_required(self):
        # Called at the start of each test, inside fresh_dynamodb().
        self._create_banks()
        self._create_privacy_groups()

    def _active_pg_match(self):
        """IndexMatch (distance 0) attributed to the active privacy group."""
        return IndexMatch(
            0,
            [
                ThreatExchangeIndicatorIndexMetadata(
                    "indicator_id",
                    "hash_value",
                    self.active_pg.privacy_group_id,
                )
            ],
        )

    def _inactive_pg_match(self):
        """IndexMatch (distance 0) attributed to the inactive privacy group."""
        return IndexMatch(
            0,
            [
                ThreatExchangeIndicatorIndexMetadata(
                    "indicator_id",
                    "hash_value",
                    self.inactive_pg.privacy_group_id,
                )
            ],
        )

    def _active_bank_match(self):
        """IndexMatch (distance 0) attributed to the active bank's member."""
        return IndexMatch(
            0,
            [
                BankedSignalIndexMetadata(
                    "signal", "signal_value", self.active_bank_member.bank_member_id
                )
            ],
        )

    def _inactive_bank_match(self):
        """IndexMatch (distance 0) attributed to the inactive bank's member."""
        return IndexMatch(
            0,
            [
                BankedSignalIndexMetadata(
                    "signal", "signal_value", self.inactive_bank_member.bank_member_id
                )
            ],
        )

    def test_matcher_filters_out_inactive_pg(self):
        with self.fresh_dynamodb():
            self._init_data_if_required()
            matcher = Matcher("", [PdqSignal, VideoMD5Signal], self.table_manager)
            filtered_matches = matcher.filter_match_results(
                [self._active_pg_match(), self._inactive_pg_match()],
                PdqSignal,
            )

            self.assertEqual(
                len(filtered_matches), 1, "Failed to filter out inactive pg match"
            )
            self.assertEqual(
                filtered_matches[0].metadata[0].privacy_group,
                self.active_pg.privacy_group_id,
                "The filtered privacy group id is wrong. It should be the active pg's id.",
            )

    def test_matcher_filters_out_based_on_distance(self):
        with self.fresh_dynamodb():
            self._init_data_if_required()

            match_1 = self._active_pg_match()
            match_2 = self._active_pg_match()
            # Push match_2 past the active PG's threshold of 31.
            match_2.distance = 100

            matcher = Matcher("", [PdqSignal, VideoMD5Signal], self.table_manager)
            filtered_matches = matcher.filter_match_results(
                [match_1, match_2], PdqSignal
            )

            self.assertEqual(
                len(filtered_matches),
                1,
                "Failed to filter out match with distance > threshold",
            )
            self.assertEqual(
                filtered_matches[0].distance,
                0,
                "Filtered out the wrong match. Match with distance = 100 should be filtered out.",
            )

    def test_matcher_filters_out_based_on_bank_active(self):
        with self.fresh_dynamodb():
            self._init_data_if_required()
            matcher = Matcher("", [PdqSignal, VideoMD5Signal], self.table_manager)
            filtered_matches = matcher.filter_match_results(
                [self._active_bank_match(), self._inactive_bank_match()],
                PdqSignal,
            )

            self.assertEqual(
                len(filtered_matches), 1, "Failed to filter out inactive bank's match"
            )
            self.assertEqual(
                filtered_matches[0].metadata[0].bank_member_id,
                self.active_bank_member.bank_member_id,
                "The filtered bank_member id is wrong. It should be the active bank's bank_member's id.",
            )
def test_bank_member_removes_from_get_members_page(self):
    """Removed members must disappear from get_all_bank_members_page.

    Creates NUM_MEMBERS members (plus the one from
    _create_bank_and_bank_member), removes every
    REMOVE_EVERY_XTH_MEMBER-th one, and re-counts via pagination.
    """
    NUM_MEMBERS = 100
    REMOVE_EVERY_XTH_MEMBER = 4

    def fetch_all_members(table_manager, bank_id):
        # Page through until exhausted; shared by both count checks.
        members = []
        exclusive_start_key = None
        while True:
            page = table_manager.get_all_bank_members_page(
                bank_id=bank_id,
                content_type=VideoContent,
                exclusive_start_key=exclusive_start_key,
            )
            members += page.items
            exclusive_start_key = page.last_evaluated_key
            if not page.has_next_page():
                break
        return members

    with self.fresh_dynamodb():
        table_manager = BanksTable(self.get_table())
        bank_id, bank_member_id = self._create_bank_and_bank_member()

        for _ in range(NUM_MEMBERS):
            table_manager.add_bank_member(
                bank_id=bank_id,
                content_type=VideoContent,
                raw_content=None,
                storage_bucket="hma-test-media",
                storage_key="irrrelevant",
                notes="",
            )

        members = fetch_all_members(table_manager, bank_id)
        self.assertEqual(
            # +1 for the member created by _create_bank_and_bank_member.
            len(members),
            NUM_MEMBERS + 1,
            "All the pages together have as many members as we added.",
        )

        count_members_removed = 0
        for i, member in enumerate(members):
            # BUGFIX: was `i // REMOVE_EVERY_XTH_MEMBER == 0`, which only
            # removed the first REMOVE_EVERY_XTH_MEMBER members. `%` removes
            # every xth member, as the constant's name intends.
            if i % REMOVE_EVERY_XTH_MEMBER == 0:
                table_manager.remove_bank_member(member.bank_member_id)
                count_members_removed += 1

        members = fetch_all_members(table_manager, bank_id)
        self.assertEqual(
            len(members),
            NUM_MEMBERS + 1 - count_members_removed,
            "All the pages together have as many members as we added minus the ones we removed.",
        )
def get_matches_api(
    datastore_table: Table,
    hma_config_table: str,
    indexes_bucket_name: str,
    writeback_queue_url: str,
    bank_table: Table,
    signal_type_mapping: HMASignalTypeMapping,
) -> bottle.Bottle:
    """
    A Closure that includes all dependencies that MUST be provided by the root
    API that this API plugs into. Declare dependencies here, but initialize in
    the root API alone.
    """
    # A prefix to all routes must be provided by the api_root app
    # The documentation below expects prefix to be '/matches/'
    matches_api = SubApp()
    HMAConfig.initialize(hma_config_table)

    banks_table = BanksTable(table=bank_table, signal_type_mapping=signal_type_mapping)

    @matches_api.get("/", apply=[jsoninator])
    def matches() -> MatchSummariesResponse:
        """
        Return all, or a filtered list of matches based on query params.

        Filter precedence: content_q wins over signal_q; with neither, a
        page of recent matches is returned.
        """
        signal_q = bottle.request.query.signal_q or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        signal_source = bottle.request.query.signal_source or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
        content_q = bottle.request.query.content_q or None  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``

        if content_q:
            records = MatchRecord.get_from_content_id(datastore_table, content_q,
                                                      signal_type_mapping)
        elif signal_q:
            records = MatchRecord.get_from_signal(datastore_table, signal_q,
                                                  signal_source or "",
                                                  signal_type_mapping)
        else:
            # TODO: Support pagination after implementing in UI.
            records = MatchRecord.get_recent_items_page(
                datastore_table, signal_type_mapping).items

        return MatchSummariesResponse(match_summaries=[
            MatchSummary(
                content_id=record.content_id,
                signal_id=record.signal_id,
                signal_source=record.signal_source,
                updated_at=record.updated_at.isoformat(),
            ) for record in records
        ])

    @matches_api.get("/match/", apply=[jsoninator])
    def match_details() -> MatchDetailsResponse:
        """
        Return the match details for a given content id.

        Missing/blank content_id yields an empty details list rather than
        an error.
        """
        results = []

        if content_id := bottle.request.query.content_id or None:  # type: ignore # ToDo refactor to use `jsoninator(<requestObj>, from_query=True)``
            results = get_match_details(
                datastore_table=datastore_table,
                banks_table=banks_table,
                content_id=content_id,
                signal_type_mapping=signal_type_mapping,
            )

        return MatchDetailsResponse(match_details=results)
def get_bank_api(bank_table: Table, bank_user_media_bucket: str,
                 submissions_queue_url: str) -> bottle.Bottle:
    """
    Closure for dependencies of the bank API

    Routes close over `table_manager`; the caller mounts the returned app
    under its own prefix.
    """
    bank_api = SubApp()
    table_manager = BanksTable(table=bank_table)

    # Bank Management

    @bank_api.get("/get-all-banks", apply=[jsoninator])
    def get_all_banks() -> AllBanksEnvelope:
        """
        Get all banks.
        """
        return AllBanksEnvelope(banks=table_manager.get_all_banks())

    @bank_api.get("/get-bank/<bank_id>", apply=[jsoninator])
    def get_bank(bank_id=None) -> Bank:
        """
        Get a specific bank from a bank_id.
        """
        bank = table_manager.get_bank(bank_id=bank_id)
        return bank

    @bank_api.post("/create-bank", apply=[jsoninator])
    def create_bank() -> Bank:
        """
        Create a bank using only the name and description.
        """
        return table_manager.create_bank(
            bank_name=bottle.request.json["bank_name"],
            bank_description=bottle.request.json["bank_description"],
        )

    @bank_api.post("/update-bank/<bank_id>", apply=[jsoninator])
    def update_bank(bank_id=None) -> Bank:
        """
        Update name and description for a bank_id.
        """
        return table_manager.update_bank(
            bank_id=bank_id,
            bank_name=bottle.request.json["bank_name"],
            bank_description=bottle.request.json["bank_description"],
        )

    # Member Management

    @bank_api.get("/get-members/<bank_id>", apply=[jsoninator])
    def get_members(bank_id=None) -> BankMembersPage:
        """
        Get a page of bank members. Use the "continuation_token" from this
        response to get subsequent pages.
        """
        continuation_token = (
            bottle.request.query.continuation_token
            and json.loads(bottle.request.query.continuation_token) or None)

        # BUGFIX: was a bare `except:` which also swallows SystemExit and
        # KeyboardInterrupt; narrow so only lookup failures become a 400.
        try:
            content_type = get_content_type_for_name(
                bottle.request.query.content_type)
        except Exception:
            bottle.abort(
                400, "content_type must be provided as a query parameter.")

        db_response = table_manager.get_all_bank_members_page(
            bank_id=bank_id,
            content_type=content_type,
            exclusive_start_key=continuation_token,
        )

        # Only hand a continuation token back if there are more pages.
        continuation_token = None
        if db_response.last_evaluated_key:
            continuation_token = uriencode(
                json.dumps(db_response.last_evaluated_key))

        return BankMembersPage(
            bank_members=with_preview_urls(db_response.items),
            continuation_token=continuation_token,
        )

    @bank_api.post("/add-member/<bank_id>", apply=[jsoninator])
    def add_member(bank_id=None) -> PreviewableBankMember:
        """
        Add a bank member. Expects a JSON object with following fields:
        - content_type: ["photo"|"video"]
        - storage_bucket: s3bucket for the media
        - storage_key: key for the media on s3
        - notes: String, any additional notes you want to associate with this
          member.

        Clients would want to use get_media_upload_url() to get a
        storage_bucket, storage_key and a upload_url before using add_member()

        Returns 200 OK with the resulting bank_member. 500 on failure.
        """
        content_type = get_content_type_for_name(
            bottle.request.json["content_type"])
        storage_bucket = bottle.request.json["storage_bucket"]
        storage_key = bottle.request.json["storage_key"]
        notes = bottle.request.json["notes"]

        # NOTE(review): bank_ops.add_bank_member elsewhere declares a
        # required `bank_member_tags` parameter that is not passed here —
        # confirm this call targets a compatible version of that API.
        return with_preview_url(
            bank_ops.add_bank_member(
                banks_table=table_manager,
                sqs_client=_get_sqs_client(),
                submissions_queue_url=submissions_queue_url,
                bank_id=bank_id,
                content_type=content_type,
                storage_bucket=storage_bucket,
                storage_key=storage_key,
                raw_content=None,
                notes=notes,
            ))

    @bank_api.post("/add-detached-member-signal/<bank_id>", apply=[jsoninator])
    def add_detached_bank_member_signal(bank_id=None) -> BankMemberSignal:
        """
        Add a virtual bank_member (without any associated media) and a
        corresponding signal. Requires JSON object with following fields:
        - signal_type: ["pdq"|"pdq_ocr","photo_md5"] -> anything from
          threatexchange.content_type.meta.get_signal_types_by_name()'s keys
        - content_type: ["photo"|"video"] to get the content_type for the
          virtual member.
        - signal_value: the hash to store against this signal.

        Will automatically de-dupe against existing signals.
        """
        content_type = get_content_type_for_name(
            bottle.request.json["content_type"])
        signal_type = get_signal_types_by_name()[
            bottle.request.json["signal_type"]]
        signal_value = bottle.request.json["signal_value"]

        return bank_ops.add_detached_bank_member_signal(
            banks_table=table_manager,
            bank_id=bank_id,
            content_type=content_type,
            signal_type=signal_type,
            signal_value=signal_value,
        )

    # Miscellaneous

    @bank_api.post("/get-media-upload-url")
    def get_media_upload_url(media_type=None):
        """
        Get a presigned S3 url that can be used by the client to PUT an
        object.

        Request Payload must be json with the following attributes:
        `media_type` must be something like ['image/gif', 'image/png',
        'application/zip']
        `extension` must be a period followed by file extension. eg. `.mp4`,
        `.jpg`
        """
        extension = bottle.request.json.get("extension")
        media_type = bottle.request.json.get("media_type")

        if (not extension) or extension[0] != ".":
            bottle.abort(400, "extension must start with a period. eg. '.mp4'")

        # Partition uploads by media type and date so the bucket stays
        # browsable; uuid avoids key collisions.
        id = str(uuid.uuid4())
        today_fragment = datetime.now().isoformat("|").split("|")[
            0]  # eg. 2019-09-12
        s3_key = f"bank-media/{media_type}/{today_fragment}/{id}{extension}"

        return {
            "storage_bucket": bank_user_media_bucket,
            "storage_key": s3_key,
            "upload_url": create_presigned_put_url(
                bucket_name=bank_user_media_bucket,
                key=s3_key,
                file_type=media_type,
                expiration=3600,
            ),
        }

    @bank_api.get("/get-member/<bank_member_id>", apply=[jsoninator])
    def get_member(bank_member_id=None) -> PreviewableBankMemberWithSignals:
        """
        Get a bank member with signals...
        """
        member = table_manager.get_bank_member(bank_member_id=bank_member_id)
        signals = table_manager.get_signals_for_bank_member(
            bank_member_id=bank_member_id)

        return PreviewableBankMemberWithSignals(**asdict(
            with_preview_url(member)), signals=signals)

    @bank_api.post("/remove-bank-member/<bank_member_id>")
    def remove_bank_member(bank_member_id: str):
        """
        Remove bank member signals from the processing index and mark
        bank_member as is_removed=True.

        Returns empty json object.
        """
        bank_ops.remove_bank_member(
            banks_table=table_manager,
            bank_member_id=bank_member_id,
        )
        return {}

    return bank_api
def lambda_handler(event, context):
    """
    SQS Events generated by the submissions API or by files being added to S3.
    Downloads files to temp-storage, identifies content_type and generates
    allowed signal_types from it.

    Saves hash output to DynamoDB, sends a message on an output queue.

    Note that this brings the contents of a file into memory. This is subject
    to the resource limitation on the lambda. Potentially extendable until
    10GB, but that would be super-expensive. [1]

    [1]: https://docs.aws.amazon.com/lambda/latest/dg/configuration-console.html
    """
    records_table = get_dynamodb().Table(DYNAMODB_TABLE)
    HMAConfig.initialize(HMA_CONFIG_TABLE)
    banks_table = BanksTable(
        get_dynamodb().Table(BANKS_TABLE),
        _get_signal_type_mapping(),
    )
    sqs_client = get_sqs_client()

    hasher = _get_hasher(_get_signal_type_mapping())

    for sqs_record in event["Records"]:
        message = json.loads(sqs_record["body"])

        # S3 emits a test event when notifications are first wired up; skip.
        if message.get("Event") == "s3:TestEvent":
            continue

        media_to_process: t.List[t.Union[S3ImageSubmission,
                                         URLSubmissionMessage,
                                         BankSubmissionMessage]] = []

        # Dispatch on message shape: URL submission, S3 image batch, or a
        # bank-member submission.
        if URLSubmissionMessage.could_be(message):
            media_to_process.append(
                URLSubmissionMessage.from_sqs_message(
                    message, _get_signal_type_mapping()))
        elif S3ImageSubmissionBatchMessage.could_be(message):
            # S3 submissions can only be images for now.
            media_to_process.extend(
                S3ImageSubmissionBatchMessage.from_sqs_message(
                    message, image_prefix=IMAGE_PREFIX).image_submissions)
        elif BankSubmissionMessage.could_be(message):
            media_to_process.append(
                BankSubmissionMessage.from_sqs_message(
                    message, _get_signal_type_mapping()))
        else:
            # NOTE(review): logger.warn is a deprecated alias of
            # logger.warning — consider migrating.
            logger.warn(f"Unprocessable Message: {message}")

        for media in media_to_process:
            if not hasher.supports(media.content_type):
                # Bank submissions are identified by bank_id, others by
                # content_id — used purely for logging.
                if isinstance(media, BankSubmissionMessage):
                    object_id = media.bank_id
                else:
                    object_id = media.content_id

                logger.warn(
                    f"Unprocessable content type: {media.content_type}, id: {object_id}"
                )
                continue

            with metrics.timer(metrics.names.hasher.download_file):
                try:
                    if hasattr(media, "key") and hasattr(media, "bucket"):
                        # Classic duck-typing. If it has key and bucket, must be an
                        # S3 submission.
                        media = t.cast(S3ImageSubmission, media)
                        bytes_: bytes = S3BucketContentSource(
                            media.bucket, IMAGE_PREFIX).get_bytes(media.content_id)
                    else:
                        media = t.cast(URLSubmissionMessage, media)
                        bytes_: bytes = URLContentSource().get_bytes(media.url)
                except Exception:
                    # Best-effort: log with whichever id we have and move on
                    # to the next media item rather than failing the batch.
                    if isinstance(media, BankSubmissionMessage):
                        object_id = media.bank_id
                    else:
                        object_id = media.content_id

                    logger.exception(
                        f"Encountered exception while trying to get_bytes for id: {object_id}. Unable to hash content."
                    )
                    continue

            for signal in hasher.get_hashes(media.content_type, bytes_):
                if isinstance(media, BankSubmissionMessage):
                    # route signals to bank datastore only.
                    bank_operations.add_bank_member_signal(
                        banks_table=banks_table,
                        bank_id=media.bank_id,
                        bank_member_id=media.bank_member_id,
                        signal_type=signal.signal_type,
                        signal_value=signal.signal_value,
                    )
                    # don't write hash records etc.
                    continue

                hash_record = PipelineHashRecord(
                    content_id=media.content_id,
                    signal_type=signal.signal_type,
                    content_hash=signal.signal_value,
                    updated_at=datetime.datetime.now(),
                )

                # Persist the hash, then notify the matcher stage.
                hasher.write_hash_record(records_table, hash_record)
                hasher.publish_hash_message(sqs_client, hash_record)

    metrics.flush()
import os

from hmalib.common.models.models_base import DynamoDBItem
from hmalib.common.models.bank import BanksTable, BankMember
from threatexchange.content_type.video import VideoContent
from mypy_boto3_dynamodb.service_resource import Table

import boto3

# One-off seeding script: creates a bank in a DynamoDB banks table and fills
# it with `num_members` identical video members (eg. for load-testing
# pagination).
dynamodb = boto3.resource("dynamodb")

table_name = ""
test_bank_name = ""
num_members = 1000

# must add these values before running; the asserts guard against
# accidentally running with the placeholders above.
assert table_name != ""
assert test_bank_name != ""

table = dynamodb.Table(table_name)
# NOTE(review): other call sites construct BanksTable with a
# signal_type_mapping argument; confirm this single-arg form matches the
# BanksTable version in use.
table_manager = BanksTable(table)

bank = table_manager.create_bank(test_bank_name, "test bank description")

for _ in range(num_members):
    table_manager.add_bank_member(
        bank_id=bank.bank_id,
        content_type=VideoContent,
        raw_content=None,
        storage_bucket="hma-test-media",
        storage_key="videos/breaking-news.mp4",
        notes="",
    )
def lambda_handler(event, context):
    """
    Listens to SQS events fired when new hash is generated. Loads the index
    stored in an S3 bucket and looks for a match. When matched, publishes a
    notification to an SNS endpoint.

    Note this is in contrast with hasher and indexer. They publish to SQS
    directly. Publishing to SQS implies there can be only one consumer.
    Because, here, in the matcher, we publish to SNS, we can plug multiple
    queues behind it and profit!
    """
    table = get_dynamodb().Table(DYNAMODB_TABLE)
    banks_table = BanksTable(get_dynamodb().Table(BANKS_TABLE))

    # get_matcher(banks_table) is invariant across records; resolve it once
    # instead of on every use inside the loop (was re-called per match).
    matcher = get_matcher(banks_table)

    for sqs_record in event["Records"]:
        message = json.loads(sqs_record["body"])

        if message.get("Event") == "TestEvent":
            logger.debug("Disregarding Test Event")
            continue

        if not PipelineHashRecord.could_be(message):
            logger.warn(
                "Could not de-serialize message in matcher lambda. Message was %s",
                message,
            )
            continue

        hash_record = PipelineHashRecord.from_sqs_message(message)
        logger.info(
            "HashRecord for contentId: %s with contentHash: %s",
            hash_record.content_id,
            hash_record.content_hash,
        )

        matches = matcher.match(hash_record.signal_type,
                                hash_record.content_hash)
        logger.info("Found %d matches.", len(matches))

        for match in matches:
            matcher.write_match_record_for_result(
                table=table,
                signal_type=hash_record.signal_type,
                content_hash=hash_record.content_hash,
                content_id=hash_record.content_id,
                match=match,
            )

        for match in matches:
            matcher.write_signal_if_not_found(
                table=table, signal_type=hash_record.signal_type, match=match)

        if matches:
            # Publish all messages together
            matcher.publish_match_message(
                content_id=hash_record.content_id,
                content_hash=hash_record.content_hash,
                matches=matches,
                sns_client=get_sns_client(),
                topic_arn=MATCHES_TOPIC_ARN,
            )

    metrics.flush()