def test_watcher(redis_connection):
    redis_connection.time = RedisTime()
    rds = redis_connection
    queue_name = get_random_id()
    out_queue = NamedQueue(queue_name, rds)
    try:
        # Create a server and hijack its running flag and the current time in 'redis'
        client = WatcherClient(rds)
        server = WatcherServer(rds, rds)
        server.running = ToggleTrue()
        rds.time.current = 0
        assert out_queue.length() == 0

        # Send a simple event to occur soon
        client.touch(10, 'one-second', queue_name, {'first': 'one'})
        server.try_run()
        assert out_queue.length() == 0  # Nothing yet
        rds.time.current = 12  # Jump forward 12 seconds
        server.try_run()
        assert out_queue.length() == 1
        assert out_queue.pop() == {'first': 'one'}

        # Send a simple event to occur soon, then change our mind
        client.touch(10, 'one-second', queue_name, {'first': 'one'})
        client.touch(20, 'one-second', queue_name, {'first': 'one'})
        server.try_run()
        assert out_queue.length() == 0  # Nothing yet

        # Set events to occur, in inverse order, reuse a key, overwrite content and timeout
        client.touch(200, 'one-second', queue_name, {'first': 'last'})
        client.touch(100, '100-second', queue_name, {'first': '100'})
        client.touch(50, '50-second', queue_name, {'first': '50'})
        server.try_run()
        assert out_queue.length() == 0  # Nothing yet

        for _ in range(15):
            rds.time.current += 20
            server.try_run()

        assert out_queue.length() == 3
        assert out_queue.pop() == {'first': '50'}
        assert out_queue.pop() == {'first': '100'}
        assert out_queue.pop() == {'first': 'last'}

        # Send a simple event to occur soon, then stop it
        rds.time.current = 0
        client.touch(10, 'one-second', queue_name, {'first': 'one'})
        server.try_run()
        assert out_queue.length() == 0  # Nothing yet
        client.clear('one-second')
        rds.time.current = 12  # Jump forward 12 seconds
        server.try_run()
        assert out_queue.length() == 0  # Still nothing, because it was cleared
    finally:
        out_queue.delete()
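
# Hedged usage sketch of the watcher API exercised by the test above. The
# semantics shown there: touch(timeout, key, queue, message) schedules
# `message` onto `queue` after `timeout` seconds, re-touching the same key
# reschedules/overwrites the event, and clear(key) cancels it. The two helper
# names below are illustrative, not part of the real API.
def schedule_notification(redis, delay_seconds, key, queue_name, payload):
    WatcherClient(redis).touch(delay_seconds, key, queue_name, payload)


def cancel_notification(redis, key):
    WatcherClient(redis).clear(key)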
def test_service_retry_limit(core, metrics):
    # This time have the service 'crash'
    sha, size = ready_body(core, {'pre': {'drop': 3}})

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=10000
        ),
        notification=dict(queue='watcher-recover', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-watcher-recover', core.redis)
    dropped_task = notification_queue.pop(timeout=RESPONSE_TIMEOUT)
    assert dropped_task
    dropped_task = IngestTask(dropped_task)
    sub = core.ds.submission.get(dropped_task.submission.sid)
    assert len(sub.errors) == 1
    assert len(sub.results) == 3
    assert core.pre_service.drops[sha] == 3
    assert core.pre_service.hits[sha] == 3

    # Wait until we get feedback from the metrics channel
    metrics.expect('ingester', 'submissions_ingested', 1)
    metrics.expect('ingester', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'service_timeouts', 3)
    metrics.expect('service', 'fail_recoverable', 3)
    metrics.expect('service', 'fail_nonrecoverable', 1)
    metrics.expect('dispatcher', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'files_completed', 1)
def test_ingest_retry(core: CoreSession, metrics):
    # ------------------------------------------------------------------------------- #
    sha, size = ready_body(core)
    original_retry_delay = assemblyline_core.ingester.ingester._retry_delay
    assemblyline_core.ingester.ingester._retry_delay = 1

    attempts = []
    failures = []
    original_submit = core.ingest.submit

    def fail_once(task):
        attempts.append(task)
        if len(attempts) > 1:
            original_submit(task)
        else:
            failures.append(task)
            raise ValueError()

    core.ingest.submit = fail_once

    try:
        core.ingest_queue.push(SubmissionInput(dict(
            metadata={},
            params=dict(
                description="file abc123",
                services=dict(selected=''),
                submitter='user',
                groups=['user'],
            ),
            notification=dict(queue='output-queue-one', threshold=0),
            files=[dict(sha256=sha, size=size, name='abc123')]
        )).as_primitives())

        notification_queue = NamedQueue('nq-output-queue-one', core.redis)
        first_task = notification_queue.pop(timeout=RESPONSE_TIMEOUT)

        # The submission should get processed fully despite the failed first submit attempt
        assert first_task is not None
        first_task = IngestTask(first_task)
        first_submission: Submission = core.ds.submission.get(first_task.submission.sid)
        assert len(attempts) == 2
        assert len(failures) == 1
        assert first_submission.state == 'completed'
        assert len(first_submission.files) == 1
        assert len(first_submission.errors) == 0
        assert len(first_submission.results) == 4

        metrics.expect('ingester', 'submissions_ingested', 1)
        metrics.expect('ingester', 'submissions_completed', 1)
        metrics.expect('ingester', 'files_completed', 1)
        metrics.expect('ingester', 'duplicates', 0)
        metrics.expect('dispatcher', 'submissions_completed', 1)
        metrics.expect('dispatcher', 'files_completed', 1)
    finally:
        core.ingest.submit = original_submit
        assemblyline_core.ingester.ingester._retry_delay = original_retry_delay
def test_dropping_early(core, metrics):
    # ------------------------------------------------------------------------------- #
    # This time have a file get marked for dropping by a service
    sha, size = ready_body(core, {'pre': {'result': {'drop_file': True}}})

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=10000
        ),
        notification=dict(queue='drop', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-drop', core.redis)
    dropped_task = notification_queue.pop(timeout=RESPONSE_TIMEOUT)
    dropped_task = IngestTask(dropped_task)
    sub = core.ds.submission.get(dropped_task.submission.sid)
    assert len(sub.files) == 1
    assert len(sub.results) == 1

    metrics.expect('ingester', 'submissions_ingested', 1)
    metrics.expect('ingester', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'files_completed', 1)
def test_max_extracted_in_several(core):
    # Make a set of files in a non-trivial tree that adds up to more than 3 (max_extracted) files
    children = [
        ready_extract(core, [ready_body(core)[0], ready_body(core)[0]])[0],
        ready_extract(core, [ready_body(core)[0], ready_body(core)[0]])[0]
    ]
    sha, size = ready_extract(core, children)

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=3
        ),
        notification=dict(queue='test-extracted-in-several', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-test-extracted-in-several', core.redis)
    task = IngestTask(notification_queue.pop(timeout=10))
    sub: Submission = core.ds.submission.get(task.submission.sid)
    assert len(sub.files) == 1
    # We should only get results for each file up to the max extracted count
    assert len(sub.results) == 4 * (1 + 3)  # 4 services; 1 original file + 3 extracted files
    assert len(sub.errors) == 3  # The number of children that errored out
def test_depth_limit(core):
    # Make a nested set of files that goes deeper than the max depth by one
    sha, size = ready_body(core)
    for _ in range(core.config.submission.max_extraction_depth + 1):
        sha, size = ready_extract(core, sha)

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            # Make sure we can extract enough files that we will definitely hit the depth limit first
            max_extracted=core.config.submission.max_extraction_depth + 10
        ),
        notification=dict(queue='test-depth-limit', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-test-depth-limit', core.redis)
    start = time.time()
    task = notification_queue.pop(timeout=10)
    print("notification time waited", time.time() - start)
    assert task is not None
    task = IngestTask(task)
    sub: Submission = core.ds.submission.get(task.submission.sid)
    assert len(sub.files) == 1
    # We should only get results for each file up to the max depth
    assert len(sub.results) == 4 * core.config.submission.max_extraction_depth
    assert len(sub.errors) == 1
def test_max_extracted_in_one(core):
    # Make a set of files that is bigger than max_extracted (3 in this case)
    children = [ready_body(core)[0] for _ in range(5)]
    sha, size = ready_extract(core, children)

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=3
        ),
        notification=dict(queue='test-extracted-in-one', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-test-extracted-in-one', core.redis)
    start = time.time()
    task = notification_queue.pop(timeout=10)
    print("notification time waited", time.time() - start)
    assert task is not None
    task = IngestTask(task)
    sub: Submission = core.ds.submission.get(task.submission.sid)
    assert len(sub.files) == 1
    # We should only get results for each file up to the max extracted count
    assert len(sub.results) == 4 * (1 + 3)
    assert len(sub.errors) == 2  # The number of children that errored out
def run_once():
    counter.reset_mock()

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
        ),
        notification=dict(queue='1', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-1', core.redis)
    first_task = notification_queue.pop(timeout=5)

    # The submission should get processed fully
    assert first_task is not None
    first_task = IngestTask(first_task)
    first_submission: Submission = core.ds.submission.get(first_task.submission.sid)
    assert first_submission.state == 'completed'
    assert len(first_submission.files) == 1
    assert len(first_submission.errors) == 0
    assert len(first_submission.results) == 4

    return first_submission.sid
def test_service_retry_limit(core):
    watch = WatcherServer(redis=core.redis, redis_persist=core.redis)
    watch.start()
    try:
        # This time have the service 'crash'
        sha, size = ready_body(core, {'pre': {'drop': 3}})

        core.ingest_queue.push(SubmissionInput(dict(
            metadata={},
            params=dict(
                description="file abc123",
                services=dict(selected=''),
                submitter='user',
                groups=['user'],
                max_extracted=10000
            ),
            notification=dict(queue='watcher-recover', threshold=0),
            files=[dict(sha256=sha, size=size, name='abc123')]
        )).as_primitives())

        notification_queue = NamedQueue('nq-watcher-recover', core.redis)
        dropped_task = notification_queue.pop(timeout=16)
        assert dropped_task
        dropped_task = IngestTask(dropped_task)
        sub = core.ds.submission.get(dropped_task.submission.sid)
        assert len(sub.errors) == 1
        assert len(sub.results) == 3
        assert core.pre_service.drops[sha] == 3
        assert core.pre_service.hits[sha] == 3
    finally:
        watch.stop()
        watch.join()
def get_all_messages(notification_queue, **kwargs):
    """
    Get all messages on the specified notification queue

    Variables:
    notification_queue    => Queue to get the messages from

    Arguments:
    None

    Data Block:
    None

    Result example:
    []            # List of messages
    """
    resp_list = []
    u = NamedQueue("nq-%s" % notification_queue,
                   host=config.core.redis.persistent.host,
                   port=config.core.redis.persistent.port)

    while True:
        msg = u.pop(blocking=False)
        if msg is None:
            break
        resp_list.append(msg)

    return make_api_response(resp_list)
def test_extracted_file(core, metrics):
    sha, size = ready_extract(core, ready_body(core)[0])

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=10000
        ),
        notification=dict(queue='text-extracted-file', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-text-extracted-file', core.redis)
    task = notification_queue.pop(timeout=RESPONSE_TIMEOUT)
    assert task
    task = IngestTask(task)
    sub = core.ds.submission.get(task.submission.sid)
    assert len(sub.files) == 1
    assert len(sub.results) == 8
    assert len(sub.errors) == 0

    metrics.expect('ingester', 'submissions_ingested', 1)
    metrics.expect('ingester', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'files_completed', 2)
def backup_worker(worker_id, instance_id, working_dir):
    datastore = forge.get_datastore(archive_access=True)
    worker_queue = NamedQueue(f"r-worker-{instance_id}", ttl=1800)
    done_queue = NamedQueue(f"r-done-{instance_id}", ttl=1800)
    hash_queue = Hash(f"r-hash-{instance_id}")
    stopping = False
    with open(os.path.join(working_dir, "backup.part%s" % worker_id), "w+") as backup_file:
        while True:
            data = worker_queue.pop(timeout=1)
            if data is None:
                if stopping:
                    break
                continue

            if data.get('stop', False):
                if not stopping:
                    stopping = True
                else:
                    # This stop message is for a sibling worker: sleep briefly and put it back
                    time.sleep(round(random.uniform(0.050, 0.250), 3))
                    worker_queue.push(data)
                continue

            missing = False
            success = True
            try:
                to_write = datastore.get_collection(data['bucket_name']).get(data['key'], as_obj=False)
                if to_write:
                    if data.get('follow_keys', False):
                        for bucket, bucket_key, getter in FOLLOW_KEYS.get(data['bucket_name'], []):
                            for key in getter(to_write.get(bucket_key, None)):
                                hash_key = "%s_%s" % (bucket, key)
                                if not hash_queue.exists(hash_key):
                                    hash_queue.add(hash_key, "True")
                                    worker_queue.push({
                                        "bucket_name": bucket,
                                        "key": key,
                                        "follow_keys": True
                                    })

                    backup_file.write(json.dumps((data['bucket_name'], data['key'], to_write)) + "\n")
                else:
                    missing = True
            except Exception:
                success = False

            done_queue.push({
                "success": success,
                "missing": missing,
                "bucket_name": data['bucket_name'],
                "key": data['key']
            })

    done_queue.push({"stopped": True})
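
# The {'stop': True} handling above is a poison-pill shutdown: the producer
# pushes one stop message per worker, and a worker that drains a second pill
# while already stopping pushes it back for its siblings. A minimal
# stdlib-only sketch of the same pattern (illustrative, not Assemblyline API):
import queue
import threading


def _worker(q: "queue.Queue"):
    while True:
        item = q.get()
        if item is None:  # poison pill: this worker is done
            break
        # ... process item ...


def run_pool(work, worker_count=4):
    q = queue.Queue()
    threads = [threading.Thread(target=_worker, args=(q,)) for _ in range(worker_count)]
    for t in threads:
        t.start()
    for item in work:
        q.put(item)
    for _ in threads:  # one pill per worker
        q.put(None)
    for t in threads:
        t.join()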
def test_plumber_clearing(core, metrics):
    global _global_semaphore
    _global_semaphore = threading.Semaphore(value=0)
    start = time.time()

    try:
        # Have the plumber cancel tasks
        sha, size = ready_body(core, {'pre': {'hold': 60}})

        core.ingest_queue.push(SubmissionInput(dict(
            metadata={},
            params=dict(
                description="file abc123",
                services=dict(selected=''),
                submitter='user',
                groups=['user'],
                max_extracted=10000
            ),
            notification=dict(queue='test_plumber_clearing', threshold=0),
            files=[dict(sha256=sha, size=size, name='abc123')]
        )).as_primitives())

        metrics.expect('ingester', 'submissions_ingested', 1)
        service_queue = get_service_queue('pre', core.redis)

        start = time.time()
        while service_queue.length() < 1:
            if time.time() - start > RESPONSE_TIMEOUT:
                pytest.fail(f'Found {service_queue.length()}')
            time.sleep(0.1)

        service_delta = core.ds.service_delta.get('pre')
        service_delta['enabled'] = False
        core.ds.service_delta.save('pre', service_delta)

        notification_queue = NamedQueue('nq-test_plumber_clearing', core.redis)
        dropped_task = notification_queue.pop(timeout=RESPONSE_TIMEOUT)
        dropped_task = IngestTask(dropped_task)
        sub = core.ds.submission.get(dropped_task.submission.sid)
        assert len(sub.files) == 1
        assert len(sub.results) == 3
        assert len(sub.errors) == 1
        error = core.ds.error.get(sub.errors[0])
        assert "disabled" in error.response.message

        metrics.expect('ingester', 'submissions_completed', 1)
        metrics.expect('dispatcher', 'submissions_completed', 1)
        metrics.expect('dispatcher', 'files_completed', 1)
        metrics.expect('service', 'fail_recoverable', 1)
    finally:
        _global_semaphore.release()
        service_delta = core.ds.service_delta.get('pre')
        service_delta['enabled'] = True
        core.ds.service_delta.save('pre', service_delta)
def test_plumber_clearing(core):
    global _global_semaphore
    _global_semaphore = threading.Semaphore(value=0)
    start = time.time()

    watch = WatcherServer(redis=core.redis, redis_persist=core.redis)
    watch.start()

    try:
        # Have the plumber cancel tasks
        sha, size = ready_body(core, {'pre': {'semaphore': 60}})

        core.ingest_queue.push(SubmissionInput(dict(
            metadata={},
            params=dict(
                description="file abc123",
                services=dict(selected=''),
                submitter='user',
                groups=['user'],
                max_extracted=10000
            ),
            notification=dict(queue='test_plumber_clearing', threshold=0),
            files=[dict(sha256=sha, size=size, name='abc123')]
        )).as_primitives())

        service_queue = get_service_queue('pre', core.redis)
        time.sleep(0.5)
        while service_queue.length() == 0 and time.time() - start < 20:
            time.sleep(0.1)

        service_delta = core.ds.service_delta.get('pre')
        service_delta['enabled'] = False
        core.ds.service_delta.save('pre', service_delta)

        notification_queue = NamedQueue('nq-test_plumber_clearing', core.redis)
        dropped_task = notification_queue.pop(timeout=5)
        dropped_task = IngestTask(dropped_task)
        sub = core.ds.submission.get(dropped_task.submission.sid)
        assert len(sub.files) == 1
        assert len(sub.results) == 3
        assert len(sub.errors) == 1
        error = core.ds.error.get(sub.errors[0])
        assert "disabled" in error.response.message
    finally:
        _global_semaphore.release()
        service_delta = core.ds.service_delta.get('pre')
        service_delta['enabled'] = True
        core.ds.service_delta.save('pre', service_delta)
        watch.stop()
        watch.join()
def test_service_error(core, metrics):
    # ------------------------------------------------------------------------------- #
    # Have a service produce an error
    # ------------------------------------------------------------------------------- #
    # This time have one service (core-a) return a non-recoverable error for the file
    sha, size = ready_body(core, {
        'core-a': {
            'error': {
                'archive_ts': time.time() + 250,
                'sha256': 'a' * 64,
                'response': {
                    'message': 'words',
                    'status': 'FAIL_NONRECOVERABLE',
                    'service_name': 'core-a',
                    'service_tool_version': 0,
                    'service_version': '0'
                },
                'expiry_ts': time.time() + 500
            },
            'failure': True,
        }
    })

    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=10000
        ),
        notification=dict(queue='error', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-error', core.redis)
    task = IngestTask(notification_queue.pop(timeout=RESPONSE_TIMEOUT))
    sub = core.ds.submission.get(task.submission.sid)
    assert len(sub.files) == 1
    assert len(sub.results) == 3
    assert len(sub.errors) == 1

    metrics.expect('ingester', 'submissions_ingested', 1)
    metrics.expect('ingester', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'submissions_completed', 1)
    metrics.expect('dispatcher', 'files_completed', 1)
def get_messages(wq_id, **_):
    """
    Get all messages currently on a watch queue.
    Note: This method is not optimal because it requires the UI to
          poll for the information. The preferred method is the
          socket server, when possible.

    Variables:
    wq_id       => Queue to get the messages from

    Arguments:
    None

    Data Block:
    None

    Result example:
    []            # List of messages
    """
    resp_list = []
    u = NamedQueue(wq_id)

    while True:
        msg = u.pop(blocking=False)
        if msg is None:
            break
        elif msg['status'] == 'STOP':
            response = {'type': 'stop', 'err_msg': None, 'status_code': 200,
                        'msg': "All messages received, closing queue..."}
        elif msg['status'] == 'START':
            response = {'type': 'start', 'err_msg': None, 'status_code': 200,
                        'msg': "Start listening..."}
        elif msg['status'] == 'OK':
            response = {'type': 'cachekey', 'err_msg': None, 'status_code': 200,
                        'msg': msg['cache_key']}
        elif msg['status'] == 'FAIL':
            response = {'type': 'cachekeyerr', 'err_msg': None, 'status_code': 200,
                        'msg': msg['cache_key']}
        else:
            response = {'type': 'error', 'err_msg': "Unknown message", 'status_code': 400,
                        'msg': msg}

        resp_list.append(response)

    return make_api_response(resp_list)
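
# The status dispatch above is a fixed mapping, so it could equally be written
# table-driven. A hedged refactoring sketch only; the handler above remains
# the actual implementation, and `translate_watch_queue_message` is an
# illustrative name:
def translate_watch_queue_message(msg):
    simple = {
        'STOP': ('stop', "All messages received, closing queue..."),
        'START': ('start', "Start listening..."),
    }
    status = msg.get('status')
    if status in simple:
        rtype, text = simple[status]
        return {'type': rtype, 'err_msg': None, 'status_code': 200, 'msg': text}
    if status in ('OK', 'FAIL'):
        rtype = 'cachekey' if status == 'OK' else 'cachekeyerr'
        return {'type': rtype, 'err_msg': None, 'status_code': 200, 'msg': msg['cache_key']}
    return {'type': 'error', 'err_msg': "Unknown message", 'status_code': 400, 'msg': msg}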
def outstanding_services(self, sid) -> Dict[str, int]:
    """
    List outstanding services for a given submission and the number of files
    each of them still has to process.

    :param sid: Submission ID
    :return: Dictionary of services and number of files remaining per service
             e.g. {"SERVICE_NAME": 1, ... }
    """
    dispatcher_id = self.submission_assignments.get(sid)
    if dispatcher_id:
        queue_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
        queue = NamedQueue(queue_name, host=self.redis, ttl=30)
        command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE + dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
        command_queue.push(DispatcherCommandMessage({
            'kind': LIST_OUTSTANDING,
            'payload_data': ListOutstanding({
                'response_queue': queue_name,
                'submission': sid
            })
        }).as_primitives())
        return queue.pop(timeout=30)
    return {}
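
# Caller-side sketch of the request/reply round trip above: the method pushes
# a LIST_OUTSTANDING command at the dispatcher that owns the submission, then
# blocks up to 30 seconds on a one-off reply queue. `dispatch_client` is an
# illustrative name for any object exposing outstanding_services():
#
#   outstanding = dispatch_client.outstanding_services(sid) or {}
#   for service_name, remaining in outstanding.items():
#       print(f"{service_name}: {remaining} file(s) left to process")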
def get_message(notification_queue, **kwargs):
    """
    Get one message on the specified notification queue

    Variables:
    notification_queue    => Queue to get the message from

    Arguments:
    None

    Data Block:
    None

    Result example:
    {}            # A message
    """
    u = NamedQueue("nq-%s" % notification_queue,
                   host=config.core.redis.persistent.host,
                   port=config.core.redis.persistent.port)

    msg = u.pop(blocking=False)

    return make_api_response(msg)
class DirectClient(ClientBase):
    def __init__(self, log, alert_fqs=None, submission_fqs=None, lookback_time='*'):
        # Setup datastore
        config = forge.get_config()
        redis = get_client(config.core.redis.nonpersistent.host,
                           config.core.redis.nonpersistent.port,
                           False)

        self.datastore = forge.get_datastore(config=config)
        self.alert_queue = NamedQueue("replay_alert", host=redis)
        self.file_queue = NamedQueue("replay_file", host=redis)
        self.submission_queue = NamedQueue("replay_submission", host=redis)

        super().__init__(log, alert_fqs=alert_fqs, submission_fqs=submission_fqs, lookback_time=lookback_time)

    def _get_next_alert_ids(self, query, filter_queries):
        return self.datastore.alert.search(
            query, fl="alert_id,reporting_ts", sort="reporting_ts asc",
            rows=100, as_obj=False, filters=filter_queries)

    def _get_next_submission_ids(self, query, filter_queries):
        return self.datastore.submission.search(
            query, fl="sid,times.completed", sort="times.completed asc",
            rows=100, as_obj=False, filters=filter_queries)

    def _set_bulk_alert_pending(self, query, filter_queries, max_docs):
        operations = [(self.datastore.alert.UPDATE_SET, 'metadata.replay', REPLAY_PENDING)]
        self.datastore.alert.update_by_query(query, operations, filters=filter_queries, max_docs=max_docs)

    def _set_bulk_submission_pending(self, query, filter_queries, max_docs):
        operations = [(self.datastore.submission.UPDATE_SET, 'metadata.replay', REPLAY_PENDING)]
        self.datastore.submission.update_by_query(query, operations, filters=filter_queries, max_docs=max_docs)

    def _stream_alert_ids(self, query):
        return self.datastore.alert.stream_search(query, fl="alert_id,reporting_ts", as_obj=False)

    def _stream_submission_ids(self, query):
        return self.datastore.submission.stream_search(query, fl="sid,times.completed", as_obj=False)

    def create_alert_bundle(self, alert_id, bundle_path):
        temp_bundle_file = create_bundle(alert_id, working_dir=os.path.dirname(bundle_path), use_alert=True)
        os.rename(temp_bundle_file, bundle_path)

    def create_submission_bundle(self, sid, bundle_path):
        temp_bundle_file = create_bundle(sid, working_dir=os.path.dirname(bundle_path))
        os.rename(temp_bundle_file, bundle_path)

    def load_bundle(self, bundle_path, min_classification, rescan_services, exist_ok=True):
        import_bundle(bundle_path,
                      min_classification=min_classification,
                      rescan_services=rescan_services,
                      exist_ok=exist_ok)

    def set_single_alert_complete(self, alert_id):
        operations = [(self.datastore.alert.UPDATE_SET, 'metadata.replay', REPLAY_DONE)]
        self.datastore.alert.update(alert_id, operations)

    def set_single_submission_complete(self, sid):
        operations = [(self.datastore.submission.UPDATE_SET, 'metadata.replay', REPLAY_DONE)]
        self.datastore.submission.update(sid, operations)

    def get_next_alert(self):
        return self.alert_queue.pop(blocking=True, timeout=30)

    def get_next_file(self):
        return self.file_queue.pop(blocking=True, timeout=30)

    def get_next_submission(self):
        return self.submission_queue.pop(blocking=True, timeout=30)

    def put_alert(self, alert):
        self.alert_queue.push(alert)

    def put_file(self, path):
        self.file_queue.push(path)

    def put_submission(self, submission):
        self.submission_queue.push(submission)
class DistributedBackup(object):
    def __init__(self, working_dir, worker_count=50, spawn_workers=True, use_threading=False, logger=None):
        self.working_dir = working_dir
        self.datastore = forge.get_datastore(archive_access=True)
        self.logger = logger
        self.plist = []
        self.use_threading = use_threading
        self.instance_id = get_random_id()
        self.worker_queue = NamedQueue(f"r-worker-{self.instance_id}", ttl=1800)
        self.done_queue = NamedQueue(f"r-done-{self.instance_id}", ttl=1800)
        self.hash_queue = Hash(f"r-hash-{self.instance_id}")
        self.bucket_error = []
        self.VALID_BUCKETS = sorted(list(self.datastore.ds.get_models().keys()))
        self.worker_count = worker_count
        self.spawn_workers = spawn_workers
        self.total_count = 0
        self.error_map_count = {}
        self.missing_map_count = {}
        self.map_count = {}
        self.last_time = 0
        self.last_count = 0
        self.error_count = 0

    def cleanup(self):
        self.worker_queue.delete()
        self.done_queue.delete()
        self.hash_queue.delete()
        for p in self.plist:
            p.terminate()

    def done_thread(self, title):
        t0 = time.time()
        self.last_time = t0

        running_threads = self.worker_count

        while running_threads > 0:
            msg = self.done_queue.pop(timeout=1)

            if msg is None:
                continue

            if "stopped" in msg:
                running_threads -= 1
                continue

            bucket_name = msg.get('bucket_name', 'unknown')

            if msg.get('success', False):
                self.total_count += 1

                if msg.get("missing", False):
                    if bucket_name not in self.missing_map_count:
                        self.missing_map_count[bucket_name] = 0
                    self.missing_map_count[bucket_name] += 1
                else:
                    if bucket_name not in self.map_count:
                        self.map_count[bucket_name] = 0
                    self.map_count[bucket_name] += 1

                new_t = time.time()
                if (new_t - self.last_time) > 5:
                    if self.logger:
                        self.logger.info("%s (%s at %s keys/sec) ==> %s" %
                                         (self.total_count,
                                          new_t - self.last_time,
                                          int((self.total_count - self.last_count) / (new_t - self.last_time)),
                                          self.map_count))
                    self.last_count = self.total_count
                    self.last_time = new_t
            else:
                self.error_count += 1

                if bucket_name not in self.error_map_count:
                    self.error_map_count[bucket_name] = 0
                self.error_map_count[bucket_name] += 1

        # Cleanup
        self.cleanup()

        summary = ""
        summary += "\n########################\n"
        summary += "####### SUMMARY  #######\n"
        summary += "########################\n"
        summary += "%s items - %s errors - %s secs\n\n" % \
                   (self.total_count, self.error_count, time.time() - t0)

        for k, v in self.map_count.items():
            summary += "\t%15s: %s\n" % (k.upper(), v)

        if len(self.missing_map_count.keys()) > 0:
            summary += "\n\nMissing data:\n\n"
            for k, v in self.missing_map_count.items():
                summary += "\t%15s: %s\n" % (k.upper(), v)

        if len(self.error_map_count.keys()) > 0:
            summary += "\n\nErrors:\n\n"
            for k, v in self.error_map_count.items():
                summary += "\t%15s: %s\n" % (k.upper(), v)

        if len(self.bucket_error) > 0:
            summary += f"\nThese buckets failed to {title.lower()} completely: {self.bucket_error}\n"

        if self.logger:
            self.logger.info(summary)

    # noinspection PyBroadException,PyProtectedMember
    def backup(self, bucket_list, follow_keys=False, query=None):
        if query is None:
            query = 'id:*'

        for bucket in bucket_list:
            if bucket not in self.VALID_BUCKETS:
                if self.logger:
                    self.logger.warning("\n%s is not a valid bucket.\n\n"
                                        "The list of valid buckets is the following:\n\n\t%s\n" %
                                        (bucket.upper(), "\n\t".join(self.VALID_BUCKETS)))
                return

        targets = ', '.join(bucket_list)

        try:
            if self.logger:
                self.logger.info("\n-----------------------")
                self.logger.info("----- Data Backup -----")
                self.logger.info("-----------------------")
                self.logger.info(f"    Deep: {follow_keys}")
                self.logger.info(f"    Buckets: {targets}")
                self.logger.info(f"    Workers: {self.worker_count}")
                self.logger.info(f"    Target directory: {self.working_dir}")
                self.logger.info(f"    Filtering query: {query}")

            # Start the workers
            for x in range(self.worker_count):
                if self.use_threading:
                    t = threading.Thread(target=backup_worker, args=(x, self.instance_id, self.working_dir))
                    t.setDaemon(True)
                    t.start()
                else:
                    p = Process(target=backup_worker, args=(x, self.instance_id, self.working_dir))
                    p.start()
                    self.plist.append(p)

            # Start done thread
            dt = threading.Thread(target=self.done_thread, args=('Backup',), name="Done thread")
            dt.setDaemon(True)
            dt.start()

            # Process data buckets
            for bucket_name in bucket_list:
                try:
                    collection = self.datastore.get_collection(bucket_name)
                    for item in collection.stream_search(query, fl="id", item_buffer_size=500, as_obj=False):
                        self.worker_queue.push({
                            "bucket_name": bucket_name,
                            "key": item['id'],
                            "follow_keys": follow_keys
                        })
                except Exception as e:
                    self.cleanup()
                    if self.logger:
                        self.logger.exception(e)
                        self.logger.error("Error occurred while processing bucket %s." % bucket_name)
                    self.bucket_error.append(bucket_name)

            for _ in range(self.worker_count):
                self.worker_queue.push({"stop": True})

            dt.join()
        except Exception as e:
            if self.logger:
                self.logger.exception(e)

    def restore(self):
        try:
            if self.logger:
                self.logger.info("\n------------------------")
                self.logger.info("----- Data Restore -----")
                self.logger.info("------------------------")
                self.logger.info(f"    Workers: {self.worker_count}")
                self.logger.info(f"    Target directory: {self.working_dir}")

            # Start the workers
            for x in range(self.worker_count):
                if self.use_threading:
                    t = threading.Thread(target=restore_worker, args=(x, self.instance_id, self.working_dir))
                    t.setDaemon(True)
                    t.start()
                else:
                    p = Process(target=restore_worker, args=(x, self.instance_id, self.working_dir))
                    p.start()
                    self.plist.append(p)

            # Start done thread
            dt = threading.Thread(target=self.done_thread, args=('Restore',), name="Done thread")
            dt.setDaemon(True)
            dt.start()

            # Wait for workers to finish
            dt.join()
        except Exception as e:
            if self.logger:
                self.logger.exception(e)
class ScalerServer(ThreadedCoreBase):
    def __init__(self, config=None, datastore=None, redis=None, redis_persist=None):
        super().__init__('assemblyline.scaler', config=config, datastore=datastore,
                         redis=redis, redis_persist=redis_persist)

        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=self.redis_persist)
        self.error_count_lock = threading.Lock()
        self.error_count: dict[str, list[float]] = {}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH, host=self.redis, ttl=30 * 60)
        self.service_event_sender = EventSender('changes.services', host=self.redis)
        self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register('changes.services.*', self._handle_service_change_event)

        core_env: dict[str, str] = {}
        # If we have privileged services, we must be able to pass the necessary environment
        # variables for them to function properly.
        for secret in re.findall(r'\${\w+}', open('/etc/assemblyline/config.yml', 'r').read()) + ['UI_SERVER']:
            env_name = secret.strip("${}")
            core_env[env_name] = os.environ[env_name]

        labels = {
            'app': 'assemblyline',
            'section': 'service',
            'privilege': 'service'
        }

        if self.config.core.scaler.additional_labels:
            labels.update({k: v for k, v in (_l.split("=") for _l in self.config.core.scaler.additional_labels)})

        if KUBERNETES_AL_CONFIG:
            self.log.info(f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}")
            self.controller = KubernetesController(
                logger=self.log,
                prefix='alsvc_',
                labels=labels,
                namespace=NAMESPACE,
                priority='al-service-priority',
                cpu_reservation=self.config.services.cpu_reservation,
                log_level=self.config.logging.log_level,
                core_env=core_env)
            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_CONFIGMAP:
                self.controller.config_mount(
                    'classification-config',
                    config_map=CLASSIFICATION_CONFIGMAP,
                    key=CLASSIFICATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/classification.yml')
            if CONFIGURATION_CONFIGMAP:
                self.controller.core_config_mount(
                    'assemblyline-config',
                    config_map=CONFIGURATION_CONFIGMAP,
                    key=CONFIGURATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/config.yml')
        else:
            self.log.info("Loading Docker cluster interface.")
            self.controller = DockerController(
                logger=self.log,
                prefix=NAMESPACE,
                labels=labels,
                log_level=self.config.logging.log_level,
                core_env=core_env)
            self._service_stage_hash.delete()

            if DOCKER_CONFIGURATION_PATH and DOCKER_CONFIGURATION_VOLUME:
                self.controller.core_mounts.append((DOCKER_CONFIGURATION_VOLUME, '/etc/assemblyline/'))

                with open(os.path.join(DOCKER_CONFIGURATION_PATH, 'config.yml'), 'w') as handle:
                    yaml.dump(self.config.as_primitives(), handle)

                with open(os.path.join(DOCKER_CONFIGURATION_PATH, 'classification.yml'), 'w') as handle:
                    yaml.dump(get_classification().original_definition, handle)

            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_HOST_PATH:
                self.controller.global_mounts.append(
                    (CLASSIFICATION_HOST_PATH, '/etc/assemblyline/classification.yml'))

        # Information about services
        self.profiles: dict[str, ServiceProfile] = {}
        self.profiles_lock = threading.RLock()

        # Prepare a single threaded scheduler
        self.state = collection.Collection(period=self.config.core.metrics.export_interval)
        self.stopping = threading.Event()
        self.main_loop_exit = threading.Event()

        # Load the APM connection if any
        self.apm_client = None
        if self.config.core.metrics.apm_server.server_url:
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="scaler")

    def log_crashes(self, fn):
        @functools.wraps(fn)
        def with_logs(*args, **kwargs):
            # noinspection PyBroadException
            try:
                fn(*args, **kwargs)
            except ServiceControlError as error:
                self.log.exception(f"Error while managing service: {error.service_name}")
                self.handle_service_error(error.service_name)
            except Exception:
                self.log.exception(f'Crash in scaler: {fn.__name__}')
        return with_logs

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def add_service(self, profile: ServiceProfile):
        # We need to hold the lock the whole time we add the service,
        # we don't want the scaling thread trying to adjust the scale of a
        # deployment we haven't added to the system yet
        with self.profiles_lock:
            profile.desired_instances = max(self.controller.get_target(profile.name), profile.min_instances)
            profile.running_instances = profile.desired_instances
            profile.target_instances = profile.desired_instances
            self.log.debug(f'Starting service {profile.name} with a target of {profile.desired_instances}')
            profile.last_update = time.time()
            self.profiles[profile.name] = profile
            self.controller.add_profile(profile, scale=profile.desired_instances)

    def try_run(self):
        self.service_change_watcher.start()
        self.maintain_threads({
            'Log Container Events': self.log_container_events,
            'Process Timeouts': self.process_timeouts,
            'Service Configuration Sync': self.sync_services,
            'Service Adjuster': self.update_scaling,
            'Import Metrics': self.sync_metrics,
            'Export Metrics': self.export_metrics,
        })

    def stop(self):
        super().stop()
        self.service_change_watcher.stop()
        self.controller.stop()

    def _handle_service_change_event(self, data: ServiceChange):
        if data.operation == Operation.Removed:
            self.log.info(f'Service appears to be deleted, removing {data.name}')
            stage = self.get_service_stage(data.name)
            self.stop_service(data.name, stage)
        elif data.operation == Operation.Incompatible:
            return
        else:
            self._sync_service(self.datastore.get_service_with_delta(data.name))

    def sync_services(self):
        while self.running:
            with apm_span(self.apm_client, 'sync_services'):
                with self.profiles_lock:
                    current_services = set(self.profiles.keys())
                discovered_services: list[str] = []

                # Get all the service data
                for service in self.datastore.list_all_services(full=True):
                    self._sync_service(service)
                    discovered_services.append(service.name)

                # Find any services we have running, that are no longer in the database and remove them
                for stray_service in current_services - set(discovered_services):
                    self.log.info(f'Service appears to be deleted, removing stray {stray_service}')
                    stage = self.get_service_stage(stray_service)
                    self.stop_service(stray_service, stage)

            self.sleep(SERVICE_SYNC_INTERVAL)

    def _sync_service(self, service: Service):
        name = service.name
        stage = self.get_service_stage(service.name)
        default_settings = self.config.core.scaler.service_defaults
        image_variables: defaultdict[str, str] = defaultdict(str)
        image_variables.update(self.config.services.image_variables)

        def prepare_container(docker_config: DockerConfig) -> DockerConfig:
            docker_config.image = Template(docker_config.image).safe_substitute(image_variables)
            set_keys = set(var.name for var in docker_config.environment)
            for var in default_settings.environment:
                if var.name not in set_keys:
                    docker_config.environment.append(var)
            return docker_config

        # noinspection PyBroadException
        try:
            def disable_incompatible_service():
                service.enabled = False
                if self.datastore.service_delta.update(service.name, [
                    (self.datastore.service_delta.UPDATE_SET, 'enabled', False)
                ]):
                    # Raise awareness to other components by sending an event for the service
                    self.service_event_sender.send(service.name, {
                        'operation': Operation.Incompatible,
                        'name': service.name
                    })

            # Check if the service is considered compatible to run on this Assemblyline system
            system_spec = f'{FRAMEWORK_VERSION}.{SYSTEM_VERSION}'
            if not service.version.startswith(system_spec):
                # If the FW and SYS versions don't prefix the service version, we can't guarantee
                # the service is compatible. Disable it and treat it as incompatible due to its version.
                self.log.warning("Disabling service with incompatible version. "
                                 f"[{service.version} != '{system_spec}.X.{service.update_channel}Y'].")
                disable_incompatible_service()
            elif service.update_config and service.update_config.wait_for_update and not service.update_config.sources:
                # All signature sources of a signature-dependent service were removed.
                # Disable it and treat it as incompatible due to its source management configuration.
                self.log.warning("Disabling service with incompatible service configuration. "
                                 "Signature-dependent service has no signature sources.")
                disable_incompatible_service()

            if not service.enabled:
                self.stop_service(service.name, stage)
                return

            # Build the docker config for the dependencies. For now the dependency blob values
            # aren't set for the change key going to kubernetes because everything about
            # the dependency config should be captured in the change key that the function generates
            # internally. A change key is set for the service deployment as that includes
            # things like the submission params
            dependency_config: dict[str, Any] = {}
            dependency_blobs: dict[str, str] = {}
            for _n, dependency in service.dependencies.items():
                dependency.container = prepare_container(dependency.container)
                dependency_config[_n] = dependency
                dep_hash = get_id_from_data(dependency, length=16)
                dependency_blobs[_n] = f"dh={dep_hash}v={service.version}p={service.privileged}"

            # Check if the service dependencies have been deployed.
            dependency_keys = []
            updater_ready = stage == ServiceStage.Running
            if service.update_config:
                for _n, dependency in dependency_config.items():
                    key = self.controller.stateful_container_key(service.name, _n, dependency, '')
                    if key:
                        dependency_keys.append(_n + key)
                    else:
                        updater_ready = False

            # If stage is not set to running or a dependency container is missing start the setup process
            if not updater_ready:
                self.log.info(f'Preparing environment for {service.name}')
                # Move to the next service stage (do this first because the container we are starting may care)
                if service.update_config and service.update_config.wait_for_update:
                    self._service_stage_hash.set(name, ServiceStage.Update)
                    stage = ServiceStage.Update
                else:
                    self._service_stage_hash.set(name, ServiceStage.Running)
                    stage = ServiceStage.Running

                # Enable this service's dependencies before trying to launch the service containers
                dependency_internet = [(name, dependency.container.allow_internet_access)
                                       for name, dependency in dependency_config.items()]
                self.controller.prepare_network(service.name, service.docker_config.allow_internet_access,
                                                dependency_internet)
                for _n, dependency in dependency_config.items():
                    self.log.info(f'Launching {service.name} dependency {_n}')
                    self.controller.start_stateful_container(
                        service_name=service.name,
                        container_name=_n,
                        spec=dependency,
                        labels={'dependency_for': service.name},
                        change_key=dependency_blobs.get(_n, ''))

            # If the conditions for running are met deploy or update service containers
            if stage == ServiceStage.Running:
                # Build the docker config for the service, we are going to either create it or
                # update it so we need to know what the current configuration is either way
                docker_config = prepare_container(service.docker_config)

                # Compute a blob of service properties not included in the docker config, that
                # should still result in a service being restarted when changed
                cfg_items = get_recursive_sorted_tuples(service.config)
                dep_keys = ''.join(sorted(dependency_keys))
                config_blob = (f"c={cfg_items}sp={service.submission_params}"
                               f"dk={dep_keys}p={service.privileged}d={docker_config}")

                # Add the service to the list of services being scaled
                with self.profiles_lock:
                    if name not in self.profiles:
                        self.log.info(f"Adding "
                                      f"{f'privileged {service.name}' if service.privileged else service.name}"
                                      " to scaling")
                        self.add_service(ServiceProfile(
                            name=name,
                            min_instances=default_settings.min_instances,
                            growth=default_settings.growth,
                            shrink=default_settings.shrink,
                            config_blob=config_blob,
                            dependency_blobs=dependency_blobs,
                            backlog=default_settings.backlog,
                            max_instances=service.licence_count,
                            container_config=docker_config,
                            queue=get_service_queue(name, self.redis),
                            # Give service an extra 30 seconds to upload results
                            shutdown_seconds=service.timeout + 30,
                            privileged=service.privileged))

                    # Update RAM, CPU, licence requirements for running services
                    else:
                        profile = self.profiles[name]
                        profile.max_instances = service.licence_count
                        profile.privileged = service.privileged

                        for dependency_name, dependency_blob in dependency_blobs.items():
                            if profile.dependency_blobs[dependency_name] != dependency_blob:
                                self.log.info(f"Updating deployment information for {name}/{dependency_name}")
                                profile.dependency_blobs[dependency_name] = dependency_blob
                                self.controller.start_stateful_container(
                                    service_name=service.name,
                                    container_name=dependency_name,
                                    spec=dependency_config[dependency_name],
                                    labels={'dependency_for': service.name},
                                    change_key=dependency_blob)

                        if profile.config_blob != config_blob:
                            self.log.info(f"Updating deployment information for {name}")
                            profile.container_config = docker_config
                            profile.config_blob = config_blob
                            self.controller.restart(profile)
                            self.log.info(f"Deployment information for {name} replaced")

        except Exception:
            self.log.exception(f"Error applying service settings from: {service.name}")
            self.handle_service_error(service.name)

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def stop_service(self, name: str, current_stage: ServiceStage):
        if current_stage != ServiceStage.Off:
            # Disable this service's dependencies
            self.controller.stop_containers(labels={'dependency_for': name})
            # Mark this service as not running in the shared record
            self._service_stage_hash.set(name, ServiceStage.Off)

        # Stop any running disabled services
        if name in self.profiles or self.controller.get_target(name) > 0:
            self.log.info(f'Removing {name} from scaling')
            with self.profiles_lock:
                self.profiles.pop(name, None)
            self.controller.set_target(name, 0)

    def update_scaling(self):
        """Check if we need to scale any services up or down."""
        pool = Pool()
        while self.sleep(SCALE_INTERVAL):
            with apm_span(self.apm_client, 'update_scaling'):
                # Figure out what services are expected to be running and how many
                with elasticapm.capture_span('read_profiles'):
                    with self.profiles_lock:
                        all_profiles: dict[str, ServiceProfile] = copy.deepcopy(self.profiles)
                    raw_targets = self.controller.get_targets()
                    targets = {_p.name: raw_targets.get(_p.name, 0) for _p in all_profiles.values()}

                for name, profile in all_profiles.items():
                    self.log.debug(f'{name}')
                    self.log.debug(f'Instances \t{profile.min_instances} < {profile.desired_instances} | '
                                   f'{targets[name]} < {profile.max_instances}')
                    self.log.debug(f'Pressure \t{profile.shrink_threshold} < '
                                   f'{profile.pressure} < {profile.growth_threshold}')

                #
                #   1.  Any processes that want to release resources can always be approved first
                #
                with pool:
                    for name, profile in all_profiles.items():
                        if targets[name] > profile.desired_instances:
                            self.log.info(f"{name} wants less resources changing allocation "
                                          f"{targets[name]} -> {profile.desired_instances}")
                            pool.call(self.controller.set_target, name, profile.desired_instances)
                            targets[name] = profile.desired_instances

                    #
                    #   2.  Any processes that aren't reaching their min_instances target must be given
                    #       more resources before anyone else is considered.
                    #
                    for name, profile in all_profiles.items():
                        if targets[name] < profile.min_instances:
                            self.log.info(f"{name} isn't meeting minimum allocation "
                                          f"{targets[name]} -> {profile.min_instances}")
                            pool.call(self.controller.set_target, name, profile.min_instances)
                            targets[name] = profile.min_instances

                    #
                    #   3.  Try to estimate available resources, and based on some metric grant the
                    #       resources to each service that wants them. While this free memory
                    #       pool might be spread across many nodes, we are going to treat it like
                    #       it is one big one, and let the orchestration layer sort out the details.
                    #
                    # Recalculate the amount of free resources expanding the total quantity by the overallocation
                    free_cpu, total_cpu = self.controller.cpu_info()
                    used_cpu = total_cpu - free_cpu
                    free_cpu = total_cpu * self.config.core.scaler.cpu_overallocation - used_cpu

                    free_memory, total_memory = self.controller.memory_info()
                    used_memory = total_memory - free_memory
                    free_memory = total_memory * self.config.core.scaler.memory_overallocation - used_memory

                    #
                    def trim(prof: list[ServiceProfile]):
                        prof = [_p for _p in prof if _p.desired_instances > targets[_p.name]]
                        drop = [_p for _p in prof if _p.cpu > free_cpu or _p.ram > free_memory]
                        if drop:
                            summary = {_p.name: (_p.cpu, _p.ram) for _p in drop}
                            self.log.debug(f"Can't make more because not enough resources {summary}")
                        prof = [_p for _p in prof if _p.cpu <= free_cpu and _p.ram <= free_memory]
                        return prof

                    remaining_profiles: list[ServiceProfile] = trim(list(all_profiles.values()))
                    # The target values up until now should be in sync with the container orchestrator
                    # create a copy, so we can track which ones change in the following loop
                    old_targets = dict(targets)

                    while remaining_profiles:
                        # TODO do we need to add balancing metrics other than 'least running' for this? probably
                        remaining_profiles.sort(key=lambda _p: targets[_p.name])

                        # Add one for the profile at the bottom
                        free_memory -= remaining_profiles[0].container_config.ram_mb
                        free_cpu -= remaining_profiles[0].container_config.cpu_cores
                        targets[remaining_profiles[0].name] += 1

                        # Take out any services that should be happy now
                        remaining_profiles = trim(remaining_profiles)

                # Apply those adjustments we have made back to the controller
                with elasticapm.capture_span('write_targets'):
                    with pool:
                        for name, value in targets.items():
                            if name not in self.profiles:
                                # A service was probably added/removed while we were
                                # in the middle of this function
                                continue
                            self.profiles[name].target_instances = value
                            old = old_targets[name]
                            if value != old:
                                self.log.info(f"Scaling service {name}: {old} -> {value}")
                                pool.call(self.controller.set_target, name, value)

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def handle_service_error(self, service_name: str):
        """Handle an error occurring in the *analysis* service.

        Errors for core systems should simply be logged, and a best effort to continue made.

        For analysis services, ignore the error a few times, then disable the service.
        """
        with self.error_count_lock:
            try:
                self.error_count[service_name].append(time.time())
            except KeyError:
                self.error_count[service_name] = [time.time()]

            self.error_count[service_name] = [
                _t for _t in self.error_count[service_name]
                if _t >= time.time() - ERROR_EXPIRY_TIME
            ]

            if len(self.error_count[service_name]) >= MAXIMUM_SERVICE_ERRORS:
                self.log.warning(f"Scaler has encountered too many errors trying to load {service_name}. "
                                 "The service will be permanently disabled...")
                if self.datastore.service_delta.update(service_name, [
                    (self.datastore.service_delta.UPDATE_SET, 'enabled', False)
                ]):
                    # Raise awareness to other components by sending an event for the service
                    self.service_event_sender.send(service_name, {
                        'operation': Operation.Modified,
                        'name': service_name
                    })
                del self.error_count[service_name]

    def sync_metrics(self):
        """Check if there are any pub-sub messages we need."""
        while self.sleep(METRIC_SYNC_INTERVAL):
            with apm_span(self.apm_client, 'sync_metrics'):
                # Pull service metrics from redis
                service_data = self.status_table.items()
                for host, (service, state, time_limit) in service_data.items():
                    # If an entry hasn't expired, take it into account
                    if time.time() < time_limit:
                        self.state.update(service=service, host=host, throughput=0,
                                          busy_seconds=METRIC_SYNC_INTERVAL
                                          if state == ServiceStatus.Running else 0)

                    # If an entry expired a while ago, the host is probably not in use any more
                    if time.time() > time_limit + 600:
                        self.status_table.pop(host)

                # Download the current targets in the orchestrator while not holding the lock
                with self.profiles_lock:
                    targets = {name: profile.target_instances for name, profile in self.profiles.items()}

                # Check the set of services that might be sitting at zero instances, and if so,
                # manually check whether they are offline
                export_interval = self.config.core.metrics.export_interval

                with self.profiles_lock:
                    queues = [profile.queue for profile in self.profiles.values() if profile.queue]
                    lengths_list = pq_length(*queues)
                    lengths = {_q: _l for _q, _l in zip(queues, lengths_list)}

                    for profile_name, profile in self.profiles.items():
                        queue_length = lengths.get(profile.queue, 0)

                        # Pull out statistics from the metrics regularization
                        update = self.state.read(profile_name)
                        if update:
                            delta = time.time() - profile.last_update
                            profile.update(delta=delta, backlog=queue_length, **update)

                        # Check if we expect no messages, if so pull the queue length ourselves
                        # since there is no heartbeat
                        if targets.get(profile_name) == 0 and profile.desired_instances == 0 and profile.queue:
                            if queue_length > 0:
                                self.log.info(f"Service at zero instances has messages: "
                                              f"{profile.name} ({queue_length} in queue)")
                            profile.update(delta=export_interval, instances=0, backlog=queue_length,
                                           duty_cycle=profile.high_duty_cycle)

    def _timeout_kill(self, service, container):
        with apm_span(self.apm_client, 'timeout_kill'):
            self.controller.stop_container(service, container)
            self.status_table.pop(container)

    def process_timeouts(self):
        with concurrent.futures.ThreadPoolExecutor(10) as pool:
            futures = []

            while self.running:
                message = self.scaler_timeout_queue.pop(blocking=True, timeout=1)
                if not message:
                    continue

                with apm_span(self.apm_client, 'process_timeouts'):
                    # Process new messages
                    self.log.info(f"Killing service container: {message['container']} "
                                  f"running: {message['service']}")
                    futures.append(pool.submit(self._timeout_kill, message['service'], message['container']))

                    # Process finished
                    finished = [_f for _f in futures if _f.done()]
                    futures = [_f for _f in futures if _f not in finished]
                    for _f in finished:
                        exception = _f.exception()
                        if exception is not None:
                            self.log.error(f"Exception trying to stop timed out service container: {exception}")

    def export_metrics(self):
        while self.sleep(self.config.logging.export_interval):
            with apm_span(self.apm_client, 'export_metrics'):
                service_metrics = {}
                with self.profiles_lock:
                    for service_name, profile in self.profiles.items():
                        service_metrics[service_name] = {
                            'running': profile.running_instances,
                            'target': profile.target_instances,
                            'minimum': profile.min_instances,
                            'maximum': profile.instance_limit,
                            'dynamic_maximum': profile.max_instances,
                            'queue': profile.queue_length,
                            'duty_cycle': profile.duty_cycle,
                            'pressure': profile.pressure
                        }

                for service_name, metrics in service_metrics.items():
                    export_metrics_once(service_name, Status, metrics,
                                        host=HOSTNAME, counter_type='scaler_status',
                                        config=self.config, redis=self.redis)

                memory, memory_total = self.controller.memory_info()
                cpu, cpu_total = self.controller.cpu_info()
                metrics = {
                    'memory_total': memory_total,
                    'cpu_total': cpu_total,
                    'memory_free': memory,
                    'cpu_free': cpu
                }
                export_metrics_once('scaler', Metrics, metrics,
                                    host=HOSTNAME, counter_type='scaler',
                                    config=self.config, redis=self.redis)

    def log_container_events(self):
        """The service status table may have references to containers that have crashed. Try to remove them all."""
        while self.sleep(CONTAINER_EVENTS_LOG_INTERVAL):
            with apm_span(self.apm_client, 'log_container_events'):
                for message in self.controller.new_events():
                    self.log.warning("Container Event :: " + message)
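
# Standalone sketch of the allocation strategy in update_scaling above: grant
# one instance at a time to the candidate with the fewest instances, as long
# as it still wants more and its cpu/ram cost fits the free pool. All names
# here are illustrative; the real logic lives in update_scaling.
def allocate(wants, costs, free_cpu, free_ram, targets):
    def fits(name):
        cpu, ram = costs[name]
        return cpu <= free_cpu and ram <= free_ram

    candidates = [n for n in wants if targets[n] < wants[n] and fits(n)]
    while candidates:
        candidates.sort(key=lambda n: targets[n])  # least running first
        name = candidates[0]
        cpu, ram = costs[name]
        free_cpu -= cpu
        free_ram -= ram
        targets[name] += 1
        candidates = [n for n in candidates if targets[n] < wants[n] and fits(n)]
    return targets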
class Alerter(ServerBase):
    def __init__(self):
        super().__init__('assemblyline.alerter')
        # Publish counters to the metrics sink.
        self.counter = MetricsFactory('alerter', Metrics)
        self.datastore = forge.get_datastore(self.config)
        self.persistent_redis = get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )
        self.process_alert_message = forge.get_process_alert_message()
        self.running = False

        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)
        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}")
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="alerter")
        else:
            self.apm_client = None

    def close(self):
        if self.counter:
            self.counter.stop()

        if self.apm_client:
            elasticapm.uninstrument()

    def run_once(self):
        alert = self.alert_queue.pop(timeout=1)
        if not alert:
            return

        # Start of process alert transaction
        if self.apm_client:
            self.apm_client.begin_transaction('Process alert message')

        self.counter.increment('received')
        try:
            alert_type = self.process_alert_message(self.counter, self.datastore, self.log, alert)

            # End of process alert transaction (success)
            if self.apm_client:
                self.apm_client.end_transaction(alert_type, 'success')

            return alert_type
        except Exception as ex:  # pylint: disable=W0703
            retries = alert['alert_retries'] = alert.get('alert_retries', 0) + 1
            if retries > MAX_RETRIES:
                self.log.exception(f'Max retries exceeded for: {alert}')
            else:
                self.alert_queue.push(alert)
                if 'Submission not finalized' not in str(ex):
                    self.log.exception(f'Unhandled exception processing: {alert}')

            # End of process alert transaction (failure)
            if self.apm_client:
                self.apm_client.end_transaction('unknown', 'exception')

    def try_run(self):
        while self.running:
            self.heartbeat()
            self.run_once()
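
# Generic sketch of the requeue-with-retry-counter pattern used by run_once
# above: the retry count is stored inside the message itself, so it survives
# the round trip through the queue. Names are illustrative, not the real API.
def process_with_retries(q, handler, max_retries):
    msg = q.pop(timeout=1)
    if not msg:
        return
    try:
        handler(msg)
    except Exception:
        msg['retries'] = msg.get('retries', 0) + 1
        if msg['retries'] <= max_retries:
            q.push(msg)  # back of the queue for another attempt
        else:
            raise  # give up and surface the error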
def watch_message_queue(self, queue_id, user_info):
    queue = NamedQueue(queue_id, private=True)
    max_retry = 30
    retry = 0
    while queue_id in self.watch_queues:
        msg = queue.pop(timeout=1)
        if msg is None:
            retry += 1
            if retry >= max_retry:
                self.socketio.emit('error',
                                   {'status_code': 503,
                                    'msg': "Dispatcher does not seem to be responding..."},
                                   room=queue_id, namespace=self.namespace)
                LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - "
                            f"Max retry reached for queue: {queue_id}")
                break
            continue

        retry = 0

        try:
            status = msg['status']
            key = msg.get('cache_key', None)
        except (KeyError, ValueError, TypeError):
            LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - Unexpected message received for "
                        f"queue {queue_id}: {msg}")
            continue

        if status == 'START':
            self.socketio.emit('start', {'status_code': 200, 'msg': "Start listening..."},
                               room=queue_id, namespace=self.namespace)
            LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - "
                        f"Starting to process messages on queue: {queue_id}")
            max_retry = 300
        elif status == 'STOP':
            self.socketio.emit('stop', {'status_code': 200,
                                        'msg': "All messages received, closing queue..."},
                               room=queue_id, namespace=self.namespace)
            LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - "
                        f"Stopping monitoring queue: {queue_id}")
            break
        elif status == 'OK':
            self.socketio.emit('cachekey', {'status_code': 200, 'msg': key},
                               room=queue_id, namespace=self.namespace)
            LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - Sending result key: {key}")
        elif status == 'FAIL':
            self.socketio.emit('cachekeyerr', {'status_code': 200, 'msg': key},
                               room=queue_id, namespace=self.namespace)
            LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - Sending error key: {key}")
        else:
            LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - Unexpected message received for "
                        f"queue {queue_id}: {msg}")

    with self.connections_lock:
        self.watch_queues.pop(queue_id, None)

    self.socketio.close_room(queue_id)
    LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - "
                f"Watch queue thread terminated for queue: {queue_id}")
def test_deduplication(core):
    # ------------------------------------------------------------------------------- #
    # Submit two identical jobs, check that they get deduped by ingester
    sha, size = ready_body(core)

    for _ in range(2):
        core.ingest_queue.push(SubmissionInput(dict(
            metadata={},
            params=dict(
                description="file abc123",
                services=dict(selected=''),
                submitter='user',
                groups=['user'],
            ),
            notification=dict(queue='output-queue-one', threshold=0),
            files=[dict(sha256=sha, size=size, name='abc123')]
        )).as_primitives())

    notification_queue = NamedQueue('nq-output-queue-one', core.redis)
    first_task = notification_queue.pop(timeout=5)
    second_task = notification_queue.pop(timeout=5)

    # One of the submissions will get processed fully
    assert first_task is not None
    first_task = IngestTask(first_task)
    first_submission: Submission = core.ds.submission.get(first_task.submission.sid)
    assert first_submission.state == 'completed'
    assert len(first_submission.files) == 1
    assert len(first_submission.errors) == 0
    assert len(first_submission.results) == 4

    # The other will get processed as a duplicate
    # (Which one is the 'real' one and which is the duplicate isn't important for our purposes)
    second_task = IngestTask(second_task)
    assert second_task.submission.sid == first_task.submission.sid

    # ------------------------------------------------------------------------------- #
    # Submit the same body, but change a parameter so the cache key misses
    core.ingest_queue.push(SubmissionInput(dict(
        metadata={},
        params=dict(
            description="file abc123",
            services=dict(selected=''),
            submitter='user',
            groups=['user'],
            max_extracted=10000
        ),
        notification=dict(queue='2', threshold=0),
        files=[dict(sha256=sha, size=size, name='abc123')]
    )).as_primitives())

    notification_queue = NamedQueue('nq-2', core.redis)
    third_task = notification_queue.pop(timeout=5)
    assert third_task

    # The third task should not be deduplicated by ingester, so it will have a different submission
    third_task = IngestTask(third_task)
    third_submission: Submission = core.ds.submission.get(third_task.submission.sid)
    assert third_submission.state == 'completed'
    assert first_submission.sid != third_submission.sid
    assert len(third_submission.files) == 1
    assert len(third_submission.results) == 4
class Ingester(ThreadedCoreBase):
    def __init__(self, datastore=None, logger=None, classification=None, redis=None,
                 persistent_redis=None, metrics_name='ingester', config=None):
        super().__init__('assemblyline.ingester', logger, redis=redis, redis_persist=persistent_redis,
                         datastore=datastore, config=config)

        # Cache the user groups
        self.cache_lock = threading.RLock()
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Module path parameters are fixed at start time. Changing these involves a restart.
        self.is_low_priority = load_module_by_path(self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so this has to be done at init rather than load
        constants = forge.get_constants(self.config)
        self.priority_value: dict[str, int] = constants.PRIORITIES
        self.priority_range: dict[str, Tuple[int, int]] = constants.PRIORITY_RANGES
        self.threshold_value: dict[str, int] = constants.PRIORITY_THRESHOLDS

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester', schema=Metrics, redis=self.redis,
                                      config=self.config, name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.redis_persist)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems.
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.redis_persist)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.redis_persist)

        # Internal, timeout watch queue
        self.timeout_queue: PriorityQueue[str] = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates.
        # When a duplicate file is detected (same cache key => same file, and same
        # submission parameters) the file won't be ingested normally, but instead a reference
        # will be written to a duplicate queue. Whenever a file is finished, in the complete
        # method, not only is the original ingestion finalized, but all entries in the duplicate queue
        # are finalized as well. This has the effect that all concurrent ingestions of the same file
        # are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.redis_persist)
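        # --- Added summary (not in the source): where this state lives.
        # Volatile redis (self.redis): complete_queue, traffic_queue, timeout_queue.
        # Persistent redis (self.redis_persist): the scanning table, ingest_queue,
        # unique_queue, retry_queue, duplicate_queue and, just below, alert_queue --
        # i.e. everything that must survive a restart.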
        # Output. Submissions that should have alerts generated.
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore, redis=self.redis)

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}")
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(server_url=self.config.core.metrics.apm_server.server_url,
                                                service_name="ingester")
        else:
            self.apm_client = None

    def try_run(self):
        threads_to_maintain = {
            'Retries': self.handle_retries,
            'Timeouts': self.handle_timeouts
        }
        threads_to_maintain.update({f'Complete_{n}': self.handle_complete for n in range(COMPLETE_THREADS)})
        threads_to_maintain.update({f'Ingest_{n}': self.handle_ingest for n in range(INGEST_THREADS)})
        threads_to_maintain.update({f'Submit_{n}': self.handle_submit for n in range(SUBMIT_THREADS)})
        self.maintain_threads(threads_to_maintain)

    def handle_ingest(self):
        cpu_mark = time.process_time()
        time_mark = time.time()

        # Move from ingest to unique and waiting queues.
        # While there are entries in the ingest queue we consume chunk_size
        # entries at a time and move unique entries to uniqueq / queued and
        # duplicates to their own queues / waiting.
        while self.running:
            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

            message = self.ingest_queue.pop(timeout=1)

            cpu_mark = time.process_time()
            time_mark = time.time()

            if not message:
                continue

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            try:
                if 'submission' in message:
                    # A retried task
                    task = IngestTask(message)
                else:
                    # A new submission
                    sub = MessageSubmission(message)
                    task = IngestTask(dict(
                        submission=sub,
                        ingest_id=sub.sid,
                    ))
                    task.submission.sid = None  # Reset to new random uuid

                    # Write all input to the traffic queue
                    self.traffic_queue.publish(SubmissionMessage({
                        'msg': sub,
                        'msg_type': 'SubmissionIngested',
                        'sender': 'ingester',
                    }).as_primitives())

            except (ValueError, TypeError) as error:
                self.counter.increment('error')
                self.log.exception(f"Dropped ingest submission {message} because {str(error)}")

                # End of ingest message (value_error)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_input', 'value_error')
                continue

            self.ingest(task)

            # End of ingest message (success)
            if self.apm_client:
                self.apm_client.end_transaction('ingest_input', 'success')
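    # --- Illustrative only (not in the source): handle_ingest accepts two
    # message shapes on the ingest queue. A fresh submission is a
    # MessageSubmission primitive, roughly:
    #     {'metadata': {...}, 'params': {...}, 'notification': {...},
    #      'files': [{'sha256': ..., 'size': ..., 'name': ...}]}
    # while a retried task is a full IngestTask primitive, recognised by its
    # top-level 'submission' key.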
    def handle_submit(self):
        time_mark, cpu_mark = time.time(), time.process_time()

        while self.running:
            # noinspection PyBroadException
            try:
                self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
                self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

                # Check if there is room for more submissions
                length = self.scanning.length()
                if length >= self.config.core.ingester.max_inflight:
                    self.sleep(0.1)
                    time_mark, cpu_mark = time.time(), time.process_time()
                    continue

                raw = self.unique_queue.blocking_pop(timeout=3)
                time_mark, cpu_mark = time.time(), time.process_time()
                if not raw:
                    continue

                # Start of ingest message
                if self.apm_client:
                    self.apm_client.begin_transaction('ingest_msg')

                task = IngestTask(raw)

                # Check if we need to drop a file for capacity reasons, but only if the
                # number of files in flight is already over 80%
                if length >= self.config.core.ingester.max_inflight * 0.8 and self.drop(task):
                    # End of ingest message (dropped)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'dropped')
                    continue

                if self.is_whitelisted(task):
                    # End of ingest message (whitelisted)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'whitelisted')
                    continue

                # Check if this file has been previously processed.
                pprevious, previous, score, scan_key = None, None, None, None
                if not task.submission.params.ignore_cache:
                    pprevious, previous, score, scan_key = self.check(task)
                else:
                    scan_key = self.stamp_filescore_key(task)

                # If it HAS been previously processed, we are dealing with a resubmission.
                # finalize will decide what to do, and put the task back in the queue,
                # rewritten properly, if we are going to run it again.
                if previous:
                    if not task.submission.params.services.resubmit and not pprevious:
                        self.log.warning(f"No psid for what looks like a resubmission of "
                                         f"{task.submission.files[0].sha256}: {scan_key}")
                    self.finalize(pprevious, previous, score, task)
                    # End of ingest message (finalized)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'finalized')
                    continue

                # We have decided this file is worth processing.
                # Add the task to the scanning table; this is atomic across all submit
                # workers, so if it fails, someone beat us to the punch and we record
                # the file as a duplicate.
                if not self.scanning.add(scan_key, task.as_primitives()):
                    self.log.debug('Duplicate %s', task.submission.files[0].sha256)
                    self.counter.increment('duplicates')
                    self.duplicate_queue.push(_dup_prefix + scan_key, task.as_primitives())
                    # End of ingest message (duplicate)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'duplicate')
                    continue

                # We have managed to add the task to the scan table, so now we go
                # ahead with the submission process.
                try:
                    self.submit(task)
                    # End of ingest message (submitted)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'submitted')
                    continue
                except Exception as _ex:
                    # For some reason (contained in `ex`) we have failed the submission.
                    # The rest of this function is error handling/recovery.
                    ex = _ex
                    # traceback = _ex.__traceback__

                    self.counter.increment('error')

                    should_retry = True
                    if isinstance(ex, CorruptedFileStoreException):
                        self.log.error("Submission for file '%s' failed due to corrupted "
                                       "filestore: %s" % (task.sha256, str(ex)))
                        should_retry = False
                    elif isinstance(ex, DataStoreException):
                        trace = exceptions.get_stacktrace_info(ex)
                        self.log.error("Submission for file '%s' failed due to "
                                       "data store error:\n%s" % (task.sha256, trace))
                    elif not isinstance(ex, FileStoreException):
                        trace = exceptions.get_stacktrace_info(ex)
                        self.log.error("Submission for file '%s' failed: %s" % (task.sha256, trace))

                    # Reclaim the scanning entry before retrying; check the raw value
                    # before constructing the task so a missing entry can't blow up.
                    raw = self.scanning.pop(scan_key)
                    if not raw:
                        self.log.error('No scanning entry for %s', task.sha256)
                        # End of ingest message (no_scan_entry)
                        if self.apm_client:
                            self.apm_client.end_transaction('ingest_submit', 'no_scan_entry')
                        continue
                    task = IngestTask(raw)

                    if not should_retry:
                        # End of ingest message (cannot_retry)
                        if self.apm_client:
                            self.apm_client.end_transaction('ingest_submit', 'cannot_retry')
                        continue

                    self.retry(task, scan_key, ex)
                    # End of ingest message (retry)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'retried')

            except Exception:
                self.log.exception("Unexpected error")
                # End of ingest message (exception)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit', 'exception')
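    # --- Added summary (not in the source): how submit() failures are handled
    # by the recovery path above.
    #   CorruptedFileStoreException -> logged, task dropped (no retry)
    #   DataStoreException          -> logged with stacktrace, retried
    #   FileStoreException          -> retried without an extra log line
    #   anything else               -> logged with stacktrace, retried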
    def handle_complete(self):
        while self.running:
            result = self.complete_queue.pop(timeout=3)
            if not result:
                continue

            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            sub = DatabaseSubmission(result)
            self.completed(sub)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(sid=sub.sid)
                self.apm_client.end_transaction('ingest_complete', 'success')

            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

    def handle_retries(self):
        tasks = []
        while self.sleep(0 if tasks else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_retries')

            tasks = self.retry_queue.dequeue_range(upper_limit=isotime.now(), num=100)

            for task in tasks:
                self.ingest_queue.push(task)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(retries=len(tasks))
                self.apm_client.end_transaction('ingest_retries', 'success')

            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

    def handle_timeouts(self):
        timeouts = []
        while self.sleep(0 if timeouts else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_timeouts')

            timeouts = self.timeout_queue.dequeue_range(upper_limit=isotime.now(), num=100)

            for scan_key in timeouts:
                # noinspection PyBroadException
                try:
                    actual_timeout = False

                    # Remove the entry from the hash of submissions in progress.
                    entry = self.scanning.pop(scan_key)
                    if entry:
                        actual_timeout = True
                        self.log.error("Submission timed out for %s: %s", scan_key, str(entry))

                    dup = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)
                    if dup:
                        actual_timeout = True

                    while dup:
                        self.log.error("Submission timed out for %s: %s", scan_key, str(dup))
                        dup = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)

                    if actual_timeout:
                        self.counter.increment('timed_out')
                except Exception:
                    self.log.exception("Problem timing out %s:", scan_key)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(timeouts=len(timeouts))
                self.apm_client.end_transaction('ingest_timeouts', 'success')

            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

    def get_groups_from_user(self, username: str) -> List[str]:
        # Reset the group cache at the top of each hour
        if time.time() // HOUR_IN_SECONDS > self._user_groups_reset:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        # Get the groups for this user if not known
        if username not in self._user_groups:
            user_data = self.datastore.user.get(username)
            if user_data:
                self._user_groups[username] = user_data.groups
            else:
                self._user_groups[username] = []
        return self._user_groups[username]
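    # --- Illustrative only (not in the source): the group cache resets when
    # the integer hour index changes. E.g. with HOUR_IN_SECONDS = 3600, a reset
    # stamp taken at t=5400s stores 5400 // 3600 = 1; the first lookup after
    # t=7200s sees 7200 // 3600 = 2 > 1 and flushes the cache.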
    def ingest(self, task: IngestTask):
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Task received for processing")

        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task, failure="Invalid sha256", logfunc=self.log.warning)
            return

        # Clean up metadata strings; since we may delete some, iterate on a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(f'[{task.ingest_id} :: {task.sha256}] '
                              f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(task.params.submitter)

        # Check if this file is already being processed
        self.stamp_filescore_key(task)
        pprevious, previous, score = None, None, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task, count_miss=False)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        current_time = now()
        if priority and self.expired(current_time - task.submission.time.timestamp(), 0):
            priority = (priority / 10) or 1

        param.priority = priority
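        # --- Illustrative only (not in the source; threshold numbers are made
        # up): with PRIORITY_THRESHOLDS like {'critical': 500, 'high': 100} and
        # a cached score of 120, the loop above lands on the 'high' priority;
        # with no cached score and is_low_priority(task) true, the task gets
        # the 'low' value instead of the 'medium' default.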
        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission.)
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)

            # On cache hits of any kind we want to send out a completed message
            self.traffic_queue.publish(SubmissionMessage({
                'msg': task.submission,
                'msg_type': 'SubmissionCompleted',
                'sender': 'ingester',
            }).as_primitives())
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(self, task: IngestTask,
              count_miss=True) -> Tuple[Optional[str], Optional[str], Optional[float], str]:
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get_if_exists(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                if count_miss:
                    self.counter.increment('cache_miss')
                return None, None, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired")
            self.counter.increment('cache_expired')
            self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, None, None, key
        elif self.stale(age, errors):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale")
            self.counter.increment('cache_stale')
            return None, None, result.score, key

        return result.psid, result.sid, result.score, key

    def stop(self):
        super().stop()
        if self.apm_client:
            elasticapm.uninstrument()
        self.submit_client.stop()

    def stale(self, delta: float, errors: int):
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256: str = None) -> str:
        if not sha256:
            sha256 = task.submission.files[0].sha256
        key = task.submission.scan_key
        if not key:
            key = task.params.create_filescore_key(sha256)
            task.submission.scan_key = key
        return key
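    # --- Added summary (not in the source): check() has four outcomes.
    # Miss -> (None, None, None, key); expired entry -> treated as a miss after
    # purging the cached score; stale entry -> score kept but no sid/psid, so
    # the file is rescanned; fresh hit -> (psid, sid, score, key) and the
    # submission can be finalized without dispatching anything.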
    def completed(self, sub: DatabaseSubmission):
        """Invoked when notified that a submission has completed."""
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.scan_key
        if not scan_key:
            self.log.warning(f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                             f"Submission missing scan key")
            scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                             f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
        self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = [dup for dup in exhaust()]
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(note_queue, self.redis_persist)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng), sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256}, expiry, c12n, redis=self.redis)
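    # --- Illustrative only (not in the source; numbers are made up): with
    # sampling_at = {'low': 1000} and a 'low' priority range of (1, 100), a
    # task at priority 50 is handed to must_drop along with the current
    # unique_queue.count(1, 100); the deeper that backlog sits past the 1000
    # mark, the more likely the task is to be sampled out.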
    def is_whitelisted(self, task: IngestTask):
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=COMPLETE_QUEUE_NAME,
        )

        self.timeout_queue.push(int(now(_max_time)), task.submission.scan_key)
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis")

    def retry(self, task: IngestTask, scan_key: str, ex):
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded{trace}')
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission')
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})')
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.as_primitives())

    def finalize(self, psid: str, sid: str, score: float, task: IngestTask):
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                          f"to {'update' if task.params.psid else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis")
            task.params.psid = sid
            task.submission.sid = None
            task.submission.scan_key = None
            task.params.services.resubmit = []
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score: float) -> bool:
        if not task.params.generate_alert:
            return False

        if score < self.config.core.alerter.threshold:
            return False

        return True
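# --- Illustrative only (not in the source): the smallest plausible way to run
# this worker, relying on the constructor's defaults to resolve the datastore
# and redis connections. Whether bare defaults resolve correctly depends on
# the assemblyline forge configuration in your environment, so treat this
# wiring as an assumption; production deployments run this as a managed core
# component instead.
if __name__ == '__main__':
    ingester = Ingester()
    try:
        ingester.try_run()  # maintains the Retry/Timeout/Complete/Ingest/Submit threads
    finally:
        ingester.stop()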