class Dispatcher:
    """Route submissions and their files through the scheduled services.

    Submission messages are handled by ``dispatch_submission`` and file messages by
    ``dispatch_file``; the progress of each submission is tracked in a ``DispatchHash``
    kept on the non-persistent redis (``self.redis``).
    """

    def __init__(self, datastore, redis, redis_persist, logger, counter_name='dispatcher'):
        # Load the datastore collections that we are going to be using
        self.datastore: AssemblylineDatastore = datastore
        self.log: logging.Logger = logger
        self.submissions: Collection = datastore.submission
        self.results: Collection = datastore.result
        self.errors: Collection = datastore.error
        self.files: Collection = datastore.file

        # Load the system configuration
        self.config: Config = forge.get_config()

        # Connect to the non-persistent and persistent redis servers, unless clients were injected
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        # Build some utility classes
        self.scheduler = Scheduler(datastore, self.config, self.redis)
        self.classification_engine = forge.get_classification()
        self.timeout_watcher = WatcherClient(self.redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        # Cache of queues created on demand by volatile_named_queue
        self._nonper_other_queues = {}
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)

        # Publish counters to the metrics sink.
        self.counter = MetricsFactory(metrics_type='dispatcher', schema=Metrics, name=counter_name,
                                      redis=self.redis, config=self.config)

    def volatile_named_queue(self, name: str) -> NamedQueue:
        """Return a cached NamedQueue called ``name`` on the non-persistent redis."""
        if name not in self._nonper_other_queues:
            self._nonper_other_queues[name] = NamedQueue(name, self.redis)
        return self._nonper_other_queues[name]

    def dispatch_submission(self, task: SubmissionTask):
        """
        Find any files associated with a submission and dispatch them if they are
        not marked as in progress. If all files are finished, finalize the submission.

        This version of dispatch submission doesn't verify each result, but assumes that
        the dispatch table has been kept up to date by other components.

        Preconditions:
            - File exists in the filestore and file collection in the datastore
            - Submission is stored in the datastore
        """
        submission = task.submission
        sid = submission.sid

        if not self.active_submissions.exists(sid):
            self.log.info(f"[{sid}] New submission received")
            self.active_submissions.add(sid, task.as_primitives())
        else:
            self.log.info(f"[{sid}] Received a pre-existing submission, check if it is complete")

        # Refresh the watch, this ensures that this function will be called again
        # if something goes wrong with one of the files, and it never gets invoked by dispatch_file.
        self.timeout_watcher.touch(key=sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': sid})

        # Refresh the quota hold
        if submission.params.quota_item and submission.params.submitter:
            self.log.info(f"[{sid}] Submission will count towards {submission.params.submitter.upper()} quota")
            Hash('submissions-' + submission.params.submitter, self.redis_persist).add(sid, isotime.now_as_iso())

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(submission.sid, self.redis, fetch_results=True)
        file_parents = dispatch_table.file_tree()  # Load the file tree data as well

        # All the submission files, and all the file_tree files, to be sure we don't miss any incomplete children
        unchecked_hashes = [submission_file.sha256 for submission_file in submission.files]
        unchecked_hashes = list(set(unchecked_hashes) | set(file_parents.keys()))

        # Using the file tree we can recalculate the depth of any file
        depth_limit = self.config.submission.max_extraction_depth
        file_depth = depths_from_tree(file_parents)

        # Try to find all files, and extracted files, and create task objects for them
        # (we will need the file data anyway for checking the schedule later)
        max_files = len(submission.files) + submission.params.max_extracted
        unchecked_files = []  # Files that haven't been checked yet
        try:
            for sha, file_data in self.files.multiget(unchecked_hashes).items():
                unchecked_files.append(FileTask(dict(
                    sid=sid,
                    min_classification=task.submission.classification,
                    file_info=dict(
                        magic=file_data.magic,
                        md5=file_data.md5,
                        mime=file_data.mime,
                        sha1=file_data.sha1,
                        sha256=file_data.sha256,
                        size=file_data.size,
                        type=file_data.type,
                    ),
                    depth=file_depth.get(sha, 0),
                    max_files=max_files
                )))
        except MultiKeyError as missing:
            # Some files are not in the datastore: record an error for each one and abandon the submission
            errors = []
            for file_sha in missing.keys:
                error = Error(dict(
                    archive_ts=submission.archive_ts,
                    expiry_ts=submission.expiry_ts,
                    response=dict(
                        message="Submission couldn't be completed due to missing file.",
                        service_name="dispatcher",
                        service_tool_version='4',
                        service_version='4',
                        status="FAIL_NONRECOVERABLE",
                    ),
                    sha256=file_sha,
                    type='UNKNOWN'
                ))
                error_key = error.build_key(service_tool_version=sid)
                self.datastore.error.save(error_key, error)
                errors.append(error_key)
            return self.cancel_submission(task, errors, file_parents)

        pending_files = {}  # Files that have not yet been processed

        # Track information about the results as we hit them
        file_scores: Dict[str, int] = {}

        # Load the current state of the dispatch table in one go rather than one at a time in the loop
        prior_dispatches = dispatch_table.all_dispatches()

        # Walk every known file through its schedule to find work that is still outstanding
        for file_task in unchecked_files:
            sha = file_task.file_info.sha256
            schedule = self.build_schedule(dispatch_table, submission, sha, file_task.file_info.type)

            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    # Only active services should be in this dict, so if a service that was placed in the
                    # schedule is now missing it has been disabled or taken offline.
                    service = self.scheduler.services.get(service_name)
                    if not service:
                        continue

                    # If the service is still marked as 'in progress'
                    runtime = time.time() - prior_dispatches.get(sha, {}).get(service_name, 0)
                    if runtime < service.timeout:
                        pending_files[sha] = file_task
                        continue

                    # It hasn't started, has timed out, or is finished, see if we have a result
                    result_row = dispatch_table.finished(sha, service_name)

                    # No result found, mark the file as incomplete
                    if not result_row:
                        pending_files[sha] = file_task
                        continue

                    # A service asked for the file to be dropped: skip all later stages
                    if not submission.params.ignore_filtering and result_row.drop:
                        schedule.clear()

                    # The process table is marked that a service has been abandoned due to errors
                    if result_row.is_error:
                        continue

                    # Collect information about the result
                    file_scores[sha] = file_scores.get(sha, 0) + result_row.score

        # Using the file tree find the most shallow parent of the given file
        def lowest_parent(_sha):
            # A root file won't have any parents in the dict
            if _sha not in file_parents or None in file_parents[_sha]:
                return None
            return min((file_depth.get(parent, depth_limit), parent) for parent in file_parents[_sha])[1]

        # Filter out things over the depth limit
        pending_files = {sha: ft for sha, ft in pending_files.items() if ft.depth < depth_limit}

        # Filter out files based on the extraction limits
        pending_files = {sha: ft for sha, ft in pending_files.items()
                         if dispatch_table.add_file(sha, max_files, lowest_parent(sha))}

        # If there are pending files, then at least one service, on at least one
        # file isn't done yet, and hasn't been filtered by any of the previous few steps
        # poke those files
        if pending_files:
            self.log.debug(f"[{sid}] Dispatching {len(pending_files)} files: {list(pending_files.keys())}")
            for file_task in pending_files.values():
                self.file_queue.push(file_task.as_primitives())
        else:
            self.log.debug(f"[{sid}] Finalizing submission.")
            max_score = max(file_scores.values()) if file_scores else 0  # Submissions with no results have no score
            self.finalize_submission(task, max_score, file_scores.keys())

    def _cleanup_submission(self, task: SubmissionTask, file_list: List[str]):
        """Clean up code that is the same for canceled and finished submissions.

        Deletes per-file temporary data, releases the quota hold, notifies the completion
        queue and any watchers, and stops tracking the submission.
        """
        submission = task.submission
        sid = submission.sid

        # Erase the temporary data which may have accumulated during processing
        for file_hash in file_list:
            hash_name = get_temporary_submission_data_name(sid, file_hash=file_hash)
            ExpiringHash(hash_name, host=self.redis).delete()

        if submission.params.quota_item and submission.params.submitter:
            self.log.info(f"[{sid}] Submission no longer counts toward {submission.params.submitter.upper()} quota")
            Hash('submissions-' + submission.params.submitter, self.redis_persist).pop(sid)

        if task.completed_queue:
            self.volatile_named_queue(task.completed_queue).push(submission.as_primitives())

        # Send complete message to any watchers. The watch queues must be reached through
        # the same redis connection that holds the watcher list.
        watcher_list = ExpiringSet(make_watcher_list_name(sid), host=self.redis)
        for w in watcher_list.members():
            NamedQueue(w, host=self.redis).push(WatchQueueMessage({'status': 'STOP'}).as_primitives())

        # Clear the watcher list and the timeout watcher
        watcher_list.delete()
        self.timeout_watcher.clear(sid)
        self.active_submissions.pop(sid)

        # Count the submission as 'complete' either way
        self.counter.increment('submissions_completed')

    def cancel_submission(self, task: SubmissionTask, errors, file_list):
        """The submission is being abandoned, delete everything, write failed state."""
        submission = task.submission
        sid = submission.sid

        # Pull down the dispatch table and clear it from redis
        dispatch_table = DispatchHash(submission.sid, self.redis)
        dispatch_table.delete()

        submission.classification = submission.params.classification
        submission.error_count = len(errors)
        submission.errors = errors
        submission.state = 'failed'
        submission.times.completed = isotime.now_as_iso()
        self.submissions.save(sid, submission)

        self._cleanup_submission(task, file_list)
        self.log.error(f"[{sid}] Failed")

    def finalize_submission(self, task: SubmissionTask, max_score, file_list):
        """All of the services for all of the files in this submission have finished or failed.

        Update the records in the datastore, and flush the working data from redis.
        """
        submission = task.submission
        sid = submission.sid

        # Pull down the dispatch table and clear it from redis
        dispatch_table = DispatchHash(submission.sid, self.redis)
        all_results = dispatch_table.all_results()
        errors = dispatch_table.all_extra_errors()
        dispatch_table.delete()

        # Sort the errors out of the results
        results = []
        for row in all_results.values():
            for status in row.values():
                if status.is_error:
                    errors.append(status.key)
                elif status.bucket == 'result':
                    results.append(status.key)
                else:
                    self.log.warning(f"[{sid}] Unexpected service output bucket: {status.bucket}/{status.key}")

        submission.classification = submission.params.classification
        submission.error_count = len(errors)
        submission.errors = errors
        submission.file_count = len(file_list)
        submission.results = results
        submission.max_score = max_score
        submission.state = 'completed'
        submission.times.completed = isotime.now_as_iso()
        self.submissions.save(sid, submission)

        self._cleanup_submission(task, file_list)
        self.log.info(f"[{sid}] Completed; files: {len(file_list)} results: {len(results)} "
                      f"errors: {len(errors)} score: {max_score}")

    def dispatch_file(self, task: FileTask):
        """ Handle a message describing a file to be processed.

        This file may be:
            - A new submission or extracted file.
            - A file that has just completed a stage of processing.
            - A file that has not completed a stage of processing, but this
              call has been triggered by a timeout or similar.

        If the file is totally new, we will setup a dispatch table, and fill it in.

        Once we make/load a dispatch table, we will dispatch whichever group the table
        shows us hasn't been completed yet.

        When we dispatch to a service, we check if the task is already in the dispatch
        queue. If it isn't proceed normally. If it is, check that the service is still online.
        """
        # Read the message content
        file_hash = task.file_info.sha256
        active_task = self.active_submissions.get(task.sid)

        if active_task is None:
            self.log.warning(f"[{task.sid}] Untracked submission is being processed")
            return

        submission_task = SubmissionTask(active_task)
        submission = submission_task.submission

        # Refresh the watch on the submission, we are still working on it
        self.timeout_watcher.touch(key=task.sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': task.sid})

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(task.sid, self.redis, fetch_results=True)

        # Load things that we will need to fill out the service tasks
        file_tags = ExpiringSet(task.get_tag_set_name(), host=self.redis)
        file_tags_data = file_tags.members()
        temporary_submission_data = ExpiringHash(task.get_temporary_submission_data_name(), host=self.redis)
        temporary_data = [dict(name=row[0], value=row[1]) for row in temporary_submission_data.items().items()]

        # Calculate the schedule for the file
        schedule = self.build_schedule(dispatch_table, submission, file_hash, task.file_info.type)
        started_stages = []

        # Go through each round of the schedule removing complete/failed services
        # Break when we find a stage that still needs processing
        outstanding = {}
        while schedule and not outstanding:
            stage = schedule.pop(0)
            started_stages.append(stage)

            for service_name in stage:
                service = self.scheduler.services.get(service_name)
                if not service:
                    continue

                # Load the results, if there are no results, then the service must be dispatched later
                # Don't look at if it has been dispatched, as multiple dispatches are fine,
                # but missing a dispatch isn't.
                finished = dispatch_table.finished(file_hash, service_name)
                if not finished:
                    outstanding[service_name] = service
                    continue

                # A service that terminated in an error does not affect the rest of the schedule
                if finished.is_error:
                    continue

                # If the service finished, check if the file has been dropped
                if not submission.params.ignore_filtering and finished.drop:
                    # If there are still stages in the schedule, overwrite the cached schedule with
                    # only the stages already started so later stages are never considered again.
                    # This must happen before clearing, otherwise the check can never succeed.
                    if schedule:
                        dispatch_table.schedules.set(file_hash, started_stages)
                    schedule.clear()

        # Try to retry/dispatch any outstanding services
        if outstanding:
            self.log.info(f"[{task.sid}] File {file_hash} sent to services : {', '.join(outstanding)}")

            # Find the actual file name from the list of files in submission (same for every service)
            filename = next((file.name for file in submission.files if file.sha256 == file_hash), None)

            for service_name, service in outstanding.items():
                # Build the actual service dispatch message
                config = self.build_service_config(service, submission)
                service_task = ServiceTask(dict(
                    sid=task.sid,
                    metadata=submission.metadata,
                    min_classification=task.min_classification,
                    service_name=service_name,
                    service_config=config,
                    fileinfo=task.file_info,
                    filename=filename or task.file_info.sha256,
                    depth=task.depth,
                    max_files=task.max_files,
                    ttl=submission.params.ttl,
                    ignore_cache=submission.params.ignore_cache,
                    ignore_dynamic_recursion_prevention=submission.params.ignore_dynamic_recursion_prevention,
                    tags=file_tags_data,
                    temporary_submission_data=temporary_data,
                    deep_scan=submission.params.deep_scan,
                    priority=submission.params.priority,
                ))
                dispatch_table.dispatch(file_hash, service_name)
                queue = get_service_queue(service_name, self.redis)
                queue.push(service_task.priority, service_task.as_primitives())

        else:
            # There are no outstanding services, this file is done
            # clean up the tags
            file_tags.delete()

            # If there are no outstanding ANYTHING for this submission,
            # send a message to the submission dispatcher to finalize
            self.counter.increment('files_completed')
            if dispatch_table.all_finished():
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}' starting submission finalization.")
                self.submission_queue.push({'sid': submission.sid})
            else:
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}'. Other files are not finished.")

    def build_schedule(self, dispatch_hash: DispatchHash, submission: Submission,
                       file_hash: str, file_type: str) -> List[List[str]]:
        """Rather than rebuilding the schedule every time we see a file, build it once and cache in redis."""
        cached_schedule = dispatch_hash.schedules.get(file_hash)
        if not cached_schedule:
            # Get the schedule for that file type based on the submission parameters
            obj_schedule = self.scheduler.build_schedule(submission, file_type)
            # The schedule built by the scheduling tool has the service objects, we just want the names for now
            cached_schedule = [list(stage.keys()) for stage in obj_schedule]
            dispatch_hash.schedules.add(file_hash, cached_schedule)
        return cached_schedule

    @classmethod
    def build_service_config(cls, service: Service, submission: Submission) -> Dict[str, str]:
        """Prepare the service config that will be used downstream.

        Starts from the defaults declared in ``service.submission_params`` and overlays any
        values given for this service in ``submission.params.service_spec``.

        v3 names: get_service_params get_config_data
        """
        # Load the default service config
        params = {x.name: x.default for x in service.submission_params}

        # Over write it with values from the submission
        if service.name in submission.params.service_spec:
            params.update(submission.params.service_spec[service.name])
        return params
class DispatchClient:
    """Client-side interface to the dispatcher.

    Submits work to the dispatcher queues, hands service tasks to workers
    (``request_work``) and records their outcomes (``service_finished`` / ``service_failed``).
    """

    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.timeout_watcher = WatcherClient(redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger("assemblyline.dispatching.client")
        # Use self.ds rather than the raw argument so this also works when no datastore was injected
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH, host=redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))

    def _get_services(self):
        """Return every service in the datastore, keyed by service name."""
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def dispatch_submission(self, submission: Submission, completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(
            SubmissionTask(
                dict(
                    submission=submission,
                    completed_queue=completed_queue,
                )).as_primitives())

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of file each
        of them still have to process.

        :param sid: Submission ID
        :return: Dictionary of services and number of files
                 remaining per services e.g. {"SERVICE_NAME": 1, ... }
        """
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        all_service_status = dispatch_hash.all_results()
        all_files = dispatch_hash.all_files()
        self.log.info(
            f"[{sid}] Listing outstanding services {len(all_files)} files "
            f"and {len(all_service_status)} entries found")

        output: Dict[str, int] = {}

        for file_hash in all_files:
            # The schedule might not be in the cache if the submission or file was just issued,
            # but it will be as soon as the file passes through the dispatcher
            schedule = dispatch_hash.schedules.get(file_hash)
            status_values = all_service_status.get(file_hash, {})

            # Go through the schedule stage by stage so we can react to drops
            # either we have a result and we don't need to count the file (but might drop it)
            # or we don't have a result, and we need to count that file
            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    status = status_values.get(service_name)
                    if status:
                        if status.drop:
                            schedule.clear()
                    else:
                        output[service_name] = output.get(service_name, 0) + 1

        return output

    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param service_version: Version of the service asking for work (used in error records).
        :param worker_id: Identifier of the worker asking for work.
        :param service_name: Which service needs work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately
        :return: The task found, or None if no usable task was obtained before the timeout ran out.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            try:
                return self._request_work(worker_id, service_name, service_version,
                                          blocking=blocking, timeout=remaining)
            except RetryRequestWork:
                remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version,
                      timeout, blocking) -> Optional[ServiceTask]:
        """Try once to obtain a task; raise RetryRequestWork when the dequeued task is unusable."""
        # For when we recursively retry on bad task dequeue-ing
        if int(timeout) <= 0:
            self.log.info(
                f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout))
        else:
            result = work_queue.pop(1)
            if result:
                result = result[0]

        if not result:
            self.log.info(
                f"{service_name}:{worker_id} no task returned: [empty message]"
            )
            return None
        task = ServiceTask(result)

        # If someone is supposed to be working on this task right now, we won't be able to add it
        if not self.running_tasks.add(task.key(), task.as_primitives()):
            raise RetryRequestWork()

        self.log.info(
            f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found"
        )

        process_table = DispatchHash(task.sid, self.redis)

        abandoned = process_table.dispatch_time(
            file_hash=task.fileinfo.sha256,
            service=task.service_name) == 0
        finished = process_table.finished(
            file_hash=task.fileinfo.sha256,
            service=task.service_name) is not None

        # A service might be re-dispatched as it finishes, when that is the case it can be marked as
        # both finished and dispatched, if that is the case, drop the dispatch from the table
        if finished and not abandoned:
            process_table.drop_dispatch(file_hash=task.fileinfo.sha256,
                                        service=task.service_name)

        if abandoned or finished:
            self.log.info(
                f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task already complete"
            )
            self.running_tasks.pop(task.key())
            raise RetryRequestWork()

        # Check if this task has reached the retry limit
        attempt_record = ExpiringHash(f'dispatch-hash-attempts-{task.sid}', host=self.redis)
        total_attempts = attempt_record.increment(task.key())
        self.log.info(
            f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} "
            f"task attempt {total_attempts}/3")
        if total_attempts > 3:
            self.log.warning(
                f"[{task.sid}/{task.fileinfo.sha256}] "
                f"{service_name}:{worker_id} marking task failed: TASK PREEMPTED "
            )
            error = Error(
                dict(
                    archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                    created='NOW',
                    expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                    response=dict(
                        message='The number of retries has passed the limit.',
                        service_name=task.service_name,
                        service_version=service_version,
                        status='FAIL_NONRECOVERABLE',
                    ),
                    sha256=task.fileinfo.sha256,
                    type="TASK PRE-EMPTED",
                ))
            error_key = error.build_key(task=task)
            self.service_failed(task.sid, error_key, error)
            export_metrics_once(service_name, Metrics, dict(fail_nonrecoverable=1),
                                host=worker_id, counter_type='service')
            raise RetryRequestWork()

        # Get the service information
        service_data = self.service_data[task.service_name]
        self.timeout_watcher.touch_task(timeout=int(service_data.timeout),
                                        key=f'{task.sid}-{task.key()}',
                                        worker=worker_id,
                                        task_key=task.key())
        return task

    def _notify_watchers(self, sid: str, message) -> None:
        """Push ``message`` onto every watch queue registered for submission ``sid``.

        The queues are reached through ``self.redis``, the same connection that holds the watcher list.
        """
        for watcher in self._get_watcher_list(sid).members():
            NamedQueue(watcher, host=self.redis).push(message)

    def _dispatching_error(self, task, process_table, error):
        """Record a dispatcher-generated error for ``task`` once and tell any watchers about it."""
        error_key = error.build_key(task=task)
        if process_table.add_error(error_key):
            self.errors.save(error_key, error)
            self._notify_watchers(task.sid, {'status': 'FAIL', 'cache_key': error_key})

    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=result.response.service_name,
            sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing successful results."
            )
            return
        task = ServiceTask(task)

        # Check if the service is a candidate for dynamic recursion prevention
        if not task.ignore_dynamic_recursion_prevention:
            service_info = self.service_data.get(result.response.service_name, None)
            if service_info and service_info.category == "Dynamic Analysis":
                # TODO: This should be done in lua because it can introduce race condition in the future
                #       but in the meantime it will remain this way while we can confirm it work as expected
                submission = self.active_submissions.get(sid)
                # The submission may already have been finalized and removed from the tracking table
                if submission:
                    submission['submission']['params']['services']['runtime_excluded'].append(
                        result.response.service_name)
                    self.active_submissions.set(sid, submission)
                else:
                    self.log.warning(
                        f"[{sid}/{result.sha256}] Submission is no longer tracked, could not exclude "
                        f"{result.response.service_name} from further processing.")

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            with Lock(f"lock-{result_key}", 5, self.redis):
                old = self.ds.result.get(result_key)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        # Either copy never expires, so the merged result must not expire either
                        result.expiry_ts = None
                self.ds.result.save(result_key, result)

        # Let the logs know we have received a result for this task
        if result.drop_file:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key} but processing will stop after this service."
            )
        else:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key}")

        # Store the result object and mark the service finished in the global table
        process_table = DispatchHash(task.sid, self.redis)
        remaining, duplicate = process_table.finish(
            task.fileinfo.sha256, task.service_name, result_key,
            result.result.score, result.classification, result.drop_file)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')
        if duplicate:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already "
                f"completed in the global processing table.")
            return

        # Push the result tags into redis
        new_tags = []
        for section in result.result.sections:
            new_tags.extend(tag_dict_to_list(section.tags.as_primitives()))
        if new_tags:
            tag_set = ExpiringSet(get_tag_set_name(sid=task.sid, file_hash=task.fileinfo.sha256),
                                  host=self.redis)
            tag_set.add(*new_tags)

        # Update the temporary data table for this file
        temp_data_hash = ExpiringHash(get_temporary_submission_data_name(sid=task.sid,
                                                                         file_hash=task.fileinfo.sha256),
                                      host=self.redis)
        for key, value in (temporary_data or {}).items():
            temp_data_hash.set(key, value)

        # Send the extracted files to the dispatcher
        depth_limit = self.config.submission.max_extraction_depth
        new_depth = task.depth + 1
        if new_depth < depth_limit:
            # Prepare the temporary data from the parent to build the temporary data table for
            # these newly extracted files
            parent_data = dict(temp_data_hash.items())

            for extracted_data in result.response.extracted:
                if not process_table.add_file(extracted_data.sha256, task.max_files,
                                              parent_hash=task.fileinfo.sha256):
                    self._dispatching_error(
                        task, process_table,
                        Error({
                            'archive_ts': result.archive_ts,
                            'expiry_ts': result.expiry_ts,
                            'response': {
                                'message': f"Too many files extracted for submission {task.sid} "
                                           f"{extracted_data.sha256} extracted by "
                                           f"{task.service_name} will be dropped",
                                'service_name': task.service_name,
                                'service_tool_version': result.response.service_tool_version,
                                'service_version': result.response.service_version,
                                'status': 'FAIL_NONRECOVERABLE'
                            },
                            'sha256': extracted_data.sha256,
                            'type': 'MAX FILES REACHED'
                        }))
                    continue

                # Seed the child's temporary data with the parent's before the child is dispatched,
                # so only files that will actually be processed inherit it
                if parent_data:
                    child_hash_name = get_temporary_submission_data_name(task.sid, extracted_data.sha256)
                    ExpiringHash(child_hash_name, host=self.redis).multi_set(parent_data)

                file_data = self.files.get(extracted_data.sha256)
                self.file_queue.push(
                    FileTask(
                        dict(sid=task.sid,
                             min_classification=task.min_classification.max(
                                 extracted_data.classification).value,
                             file_info=dict(
                                 magic=file_data.magic,
                                 md5=file_data.md5,
                                 mime=file_data.mime,
                                 sha1=file_data.sha1,
                                 sha256=file_data.sha256,
                                 size=file_data.size,
                                 type=file_data.type,
                             ),
                             depth=new_depth,
                             parent_hash=task.fileinfo.sha256,
                             max_files=task.max_files)).as_primitives())
        else:
            for extracted_data in result.response.extracted:
                self._dispatching_error(
                    task, process_table,
                    Error({
                        'archive_ts': result.archive_ts,
                        'expiry_ts': result.expiry_ts,
                        'response': {
                            'message': f"{task.service_name} has extracted a file "
                                       f"{extracted_data.sha256} beyond the depth limits",
                            'service_name': result.response.service_name,
                            'service_tool_version': result.response.service_tool_version,
                            'service_version': result.response.service_version,
                            'status': 'FAIL_NONRECOVERABLE'
                        },
                        'sha256': extracted_data.sha256,
                        'type': 'MAX DEPTH REACHED'
                    }))

        # If the global table said that this was the last outstanding service,
        # send a message to the dispatchers.
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

        # Send the result key to any watching systems
        self._notify_watchers(task.sid, {'status': 'OK', 'cache_key': result_key})

    def service_failed(self, sid: str, error_key: str, error: Error):
        """Notifies the dispatcher that a service failed to process a file.

        Recoverable failures are only marked in the dispatch table (and the file is re-issued);
        non-recoverable failures are saved, reported to watchers and count as the service being done.
        """
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=error.response.service_name,
            sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(
            f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error."
        )
        # A recoverable failure never updates this, so the file is always re-issued below
        remaining = -1
        # Mark the attempt to process the file over in the dispatch table
        process_table = DispatchHash(task.sid, self.redis)
        if error.response.status == "FAIL_RECOVERABLE":
            # Because the error is recoverable, we will not save it nor we will notify the user
            process_table.fail_recoverable(task.fileinfo.sha256, task.service_name)
        else:
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)
            remaining, _duplicate = process_table.fail_nonrecoverable(
                task.fileinfo.sha256, task.service_name, error_key)

            # Send the result key to any watching systems
            self._notify_watchers(task.sid, {'status': 'FAIL', 'cache_key': error_key})
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')

        # Send a message to prompt the re-issue of the task if needed
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         # Same serialized form as service_finished uses for FileTask
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

    def setup_watch_queue(self, sid):
        """
        This function takes a submission ID as a parameter and creates a unique queue where all service
        result keys for that given submission will be returned to as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        # Create a unique queue on the same redis that holds the watcher list
        queue_name = reply_queue_name(prefix="D", suffix="WQ")
        watch_queue = NamedQueue(queue_name, host=self.redis, ttl=30)
        watch_queue.push(WatchQueueMessage({'status': 'START'}).as_primitives())

        # Add the newly created queue to the list of queues for the given submission
        self._get_watcher_list(sid).add(queue_name)

        # Push all current keys to the newly created queue (Queue should have a TTL of about 30 sec to 1 minute)
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        if dispatch_hash.dispatch_count() == 0 and dispatch_hash.finished_count() == 0:
            # This table is empty? do we have this submission at all?
            submission = self.ds.submission.get(sid)
            if not submission or submission.state == 'completed':
                watch_queue.push(WatchQueueMessage({"status": "STOP"}).as_primitives())
            else:
                # We do have a submission, remind the dispatcher to work on it
                self.submission_queue.push({'sid': sid})

        else:
            all_service_status = dispatch_hash.all_results()
            for status_values in all_service_status.values():
                for status in status_values.values():
                    if status.is_error:
                        watch_queue.push(WatchQueueMessage({
                            "status": "FAIL",
                            "cache_key": status.key
                        }).as_primitives())
                    else:
                        watch_queue.push(WatchQueueMessage({
                            "status": "OK",
                            "cache_key": status.key
                        }).as_primitives())

        return queue_name

    def _get_watcher_list(self, sid):
        """Return the set of watch queue names registered for submission ``sid``."""
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)
class DispatchClient:
    """Client used by services and other components to interact with dispatchers.

    New submissions/bundles are pushed onto the shared submission queue, service
    workers pull tasks from per-service queues, and results/errors are reported
    to the result queue of the dispatcher that issued the task (identified by the
    task's ``dispatcher__`` metadata).
    """

    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        self.config = forge.get_config()
        # Fall back to connections built from the configuration when none are provided
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )
        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger("assemblyline.dispatching.client")
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        # Looked up by sid to find the id of the dispatcher handling a submission
        self.submission_assignments = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = Hash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))
        # Cached list of live dispatcher ids; refreshed by is_dispatcher when older than
        # 120 seconds or when an unknown id is looked up
        self.dispatcher_data = []
        self.dispatcher_data_age = 0.0
        # Dispatcher ids that were not found among the live instances
        self.dead_dispatchers = []

    @weak_lru(maxsize=128)
    def _get_queue_from_cache(self, name):
        """Return a NamedQueue for ``name`` using QUEUE_EXPIRY as its ttl (memoized)."""
        return NamedQueue(name, host=self.redis, ttl=QUEUE_EXPIRY)

    def _get_services(self):
        """Load every service definition from the datastore, keyed by service name."""
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def is_dispatcher(self, dispatcher_id) -> bool:
        """Check whether ``dispatcher_id`` belongs to a currently registered dispatcher.

        Ids once found missing are remembered in ``dead_dispatchers`` and rejected
        without further lookups.
        """
        if dispatcher_id in self.dead_dispatchers:
            return False
        if time.time() - self.dispatcher_data_age > 120 or dispatcher_id not in self.dispatcher_data:
            self.dispatcher_data = Dispatcher.all_instances(self.redis_persist)
            self.dispatcher_data_age = time.time()
        if dispatcher_id in self.dispatcher_data:
            return True
        else:
            self.dead_dispatchers.append(dispatcher_id)
            return False

    def dispatch_bundle(self, submission: Submission, results: Dict[str, Result],
                        file_infos: Dict[str, File], file_tree, errors: Dict[str, Error],
                        completed_queue: Optional[str] = None):
        """Insert a bundle into the dispatching system and continue scanning of its files

        Prerequisites:
            - Submission, results, file_infos and errors should already be saved in the datastore
            - Files should already be in the filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            results=results,
            file_infos=file_infos,
            file_tree=file_tree,
            errors=errors,
            completed_queue=completed_queue,
        ))

    def dispatch_submission(self, submission: Submission, completed_queue: Optional[str] = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            completed_queue=completed_queue,
        ))

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of file each
        of them still have to process.

        The request is forwarded to the dispatcher assigned to the submission, and the
        answer is read back from a temporary reply queue.

        :param sid: Submission ID
        :return: Dictionary of services and number of files remaining per services e.g. {"SERVICE_NAME": 1, ... }.
                 An empty dict when no dispatcher is assigned to the submission. If no reply arrives
                 within 30 seconds, whatever ``queue.pop`` returns on timeout is passed through
                 (presumably None).
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
            queue = NamedQueue(queue_name, host=self.redis, ttl=30)
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE + dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': LIST_OUTSTANDING,
                'payload_data': ListOutstanding({
                    'response_queue': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue.pop(timeout=30)
        return {}

    @elasticapm.capture_span(span_type='dispatch_client')
    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True, low_priority=False) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param worker_id: Identifier of the requesting worker; recorded in the task metadata.
        :param service_name: Which service needs work.
        :param service_version: Version of the requesting service (not used by the current implementation).
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately
        :param low_priority: Forwarded to the queue pop to select which end of the queue is read.
        :return: The task found, or None if no valid task could be obtained.
        """
        if not blocking:
            # A non-blocking request polls the queue exactly once; the timeout does not apply,
            # so a small (or zero) timeout must not prevent the poll.
            return self._request_work(worker_id, service_name, service_version,
                                      blocking=False, timeout=timeout, low_priority=low_priority)

        start = time.time()
        remaining = timeout
        # A popped task can still be rejected (dead dispatcher, already running), so keep
        # waiting for another one until the time budget is spent.
        while int(remaining) > 0:
            work = self._request_work(worker_id, service_name, service_version,
                                      blocking=True, timeout=remaining, low_priority=low_priority)
            if work:
                return work
            remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version,
                      timeout, blocking, low_priority=False) -> Optional[ServiceTask]:
        """Make a single attempt to pop a task for ``service_name`` and claim it for ``worker_id``.

        Returns None when the pop yields nothing, when the task was issued by a dispatcher that
        is not registered anymore, or when ``running_tasks.add`` reports the task key was not
        added (presumably because it is already claimed).
        """
        # Only a blocking pop consumes the timeout. Guard against a sub-second remaining budget,
        # since int(timeout) would be 0 (presumably meaning "wait forever" for the underlying
        # blocking pop -- TODO confirm).
        if blocking and int(timeout) <= 0:
            self.log.info(f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout), low_priority=low_priority)
        else:
            if low_priority:
                result = work_queue.unpush(1)
            else:
                result = work_queue.pop(1)
            # The non-blocking calls above return a batch; take its single element
            if result:
                result = result[0]

        if not result:
            self.log.info(f"{service_name}:{worker_id} no task returned: [empty message]")
            return None
        task = ServiceTask(result)
        task.metadata['worker__'] = worker_id
        dispatcher = task.metadata['dispatcher__']

        if not self.is_dispatcher(dispatcher):
            self.log.info(f"{service_name}:{worker_id} no task returned: [task from dead dispatcher]")
            return None

        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found")
            # Tell the issuing dispatcher that a worker has started on this task
            start_queue = self._get_queue_from_cache(DISPATCH_START_EVENTS + dispatcher)
            start_queue.push((task.sid, task.fileinfo.sha256, service_name, worker_id))
            return task
        return None

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(sid=sid, service_name=result.response.service_name, sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing successful results.")
            return
        task = ServiceTask(task)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            # Optimistic concurrency: re-read and retry whenever another writer saved first
            while True:
                old, version = self.ds.result.get_if_exists(
                    result_key, archive_access=self.config.datastore.ilm.update_archive, version=True)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        # Either copy never expires, so the merged result must not expire either
                        result.expiry_ts = None
                try:
                    self.ds.result.save(result_key, result, version=version)
                    break
                except VersionConflictException as vce:
                    self.log.info(f"Retrying to save results due to version conflict: {str(vce)}")

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w, host=self.redis).push(msg)

        # Collect the tags from every result section
        tags = []
        for section in result.result.sections:
            tags.extend(tag_dict_to_list(flatten(section.tags.as_primitives())))

        # Pull out file names if we have them
        file_names = {}
        for extracted_data in result.response.extracted:
            if extracted_data.name:
                file_names[extracted_data.sha256] = extracted_data.name

        # Report the completion to the dispatcher that issued the task
        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        ex_ts = result.expiry_ts.strftime(DATEFORMAT) if result.expiry_ts else result.archive_ts.strftime(DATEFORMAT)
        result_queue.push({
            'sid': task.sid,
            'sha256': result.sha256,
            'service_name': task.service_name,
            'service_version': result.response.service_version,
            'service_tool_version': result.response.service_tool_version,
            'archive_ts': result.archive_ts.strftime(DATEFORMAT),
            'expiry_ts': ex_ts,
            'result_summary': {
                'key': result_key,
                'drop': result.drop_file,
                'score': result.result.score,
                'children': [r.sha256 for r in result.response.extracted],
            },
            'tags': tags,
            'extracted_names': file_names,
            'temporary_data': temporary_data
        })

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_failed(self, sid: str, error_key: str, error: Error):
        """Notify the issuing dispatcher that a service task ended with an error.

        Only FAIL_NONRECOVERABLE errors are saved to the datastore; every error is
        announced to watch queues and forwarded to the dispatcher's result queue.
        """
        task_key = ServiceTask.make_key(sid=sid, service_name=error.response.service_name, sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error.")
        if error.response.status == "FAIL_NONRECOVERABLE":
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w, host=self.redis).push(msg)

        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        result_queue.push({
            'sid': task.sid,
            'service_task': task.as_primitives(),
            'error': error.as_primitives(),
            'error_key': error_key
        })

    def setup_watch_queue(self, sid: str) -> Optional[str]:
        """
        Request a unique queue where all service result keys for the given submission
        will be returned to as soon as they come in.

        The queue is created by the dispatcher assigned to the submission, which is sent
        a CREATE_WATCH command naming the queue.

        :param sid: Submission ID
        :return: The name of the watch queue that was requested, or None when no dispatcher
                 is assigned to the submission.
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="WQ")
            # Same command queue (and ttl) as used by outstanding_services
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE + dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': CREATE_WATCH,
                'payload_data': CreateWatch({
                    'queue_name': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue_name

    def _get_watcher_list(self, sid):
        """Return the expiring set holding the names of watch queues for submission ``sid``."""
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)