class WatcherClient:
    def __init__(self, redis_persist=None):
        config = forge.get_config()
        self.redis = redis_persist or get_client(
            host=config.core.redis.persistent.host,
            port=config.core.redis.persistent.port,
            private=False,
        )
        # Use the resolved client so a default connection is created when redis_persist is not provided
        self.hash = ExpiringHash(name=WATCHER_HASH, ttl=MAX_TIMEOUT, host=self.redis)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, self.redis)

    def touch(self, timeout: int, key: str, queue: str, message: dict):
        if timeout >= MAX_TIMEOUT:
            raise ValueError(f"Can't set watcher timeouts over {MAX_TIMEOUT}")
        self.hash.set(key, {'action': WatcherAction.Message, 'queue': queue, 'message': message})
        seconds, _ = retry_call(self.redis.time)
        self.queue.push(int(seconds + timeout), key)

    def touch_task(self, timeout: int, key: str, worker: str, task_key: str):
        if timeout >= MAX_TIMEOUT:
            raise ValueError(f"Can't set watcher timeouts over {MAX_TIMEOUT}")
        self.hash.set(key, {'action': WatcherAction.TimeoutTask, 'worker': worker, 'task_key': task_key})
        seconds, _ = retry_call(self.redis.time)
        self.queue.push(int(seconds + timeout), key)

    def clear(self, key: str):
        self.queue.remove(key)
        self.hash.pop(key)
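
# Illustrative sketch (not part of the original module): how a caller might arm and
# disarm a watcher. The key, queue name and payload below are hypothetical.
def _example_watcher_usage():
    watcher = WatcherClient()
    # Ask the watcher to push a message on 'my-notify-queue' if 'job-123' is still outstanding in 60 seconds
    watcher.touch(timeout=60, key='job-123', queue='my-notify-queue',
                  message={'status': 'TIMEOUT', 'job': 'job-123'})
    # ... work happens ...
    # The job finished in time, so cancel the pending timeout
    watcher.clear('job-123')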
class TaskingClient:
    """A helper class to simplify tasking for privileged services and service-server.

    This tool helps take care of interactions between the filestore, datastore,
    dispatcher, and any sources of files to be processed.
    """

    def __init__(self, datastore: AssemblylineDatastore = None, filestore: FileStore = None,
                 config=None, redis=None, redis_persist=None, identify=None):
        self.log = logging.getLogger('assemblyline.tasking_client')
        self.config = config or forge.CachedObject(forge.get_config)
        self.datastore = datastore or forge.get_datastore(self.config)
        self.dispatch_client = DispatchClient(self.datastore, redis=redis, redis_persist=redis_persist)
        self.event_sender = EventSender('changes.services', redis)
        self.filestore = filestore or forge.get_filestore(self.config)
        self.heuristic_handler = HeuristicHandler(self.datastore)
        self.heuristics = {h.heur_id: h for h in self.datastore.list_all_heuristics()}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60 * 30, host=redis)
        self.tag_safelister = forge.CachedObject(forge.get_tag_safelister,
                                                 kwargs=dict(log=self.log, config=config, datastore=self.datastore),
                                                 refresh=300)
        if identify:
            self.cleanup = False
        else:
            self.cleanup = True
        self.identify = identify or forge.get_identify(config=self.config, datastore=self.datastore, use_cache=True)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.stop()

    def stop(self):
        if self.cleanup:
            self.identify.stop()

    @elasticapm.capture_span(span_type='tasking_client')
    def upload_file(self, file_path, classification, ttl, is_section_image, expected_sha256=None):
        # Identify the file info of the uploaded file
        file_info = self.identify.fileinfo(file_path)

        # Validate SHA256 of the uploaded file
        if expected_sha256 is None or expected_sha256 == file_info['sha256']:
            file_info['archive_ts'] = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            file_info['classification'] = classification
            if ttl:
                file_info['expiry_ts'] = now_as_iso(ttl * 24 * 60 * 60)
            else:
                file_info['expiry_ts'] = None

            # Update the datastore with the uploaded file
            self.datastore.save_or_freshen_file(file_info['sha256'], file_info, file_info['expiry_ts'],
                                                file_info['classification'], is_section_image=is_section_image)

            # Upload file to the filestore (upload already checks if the file exists)
            self.filestore.upload(file_path, file_info['sha256'])
        else:
            raise TaskingClientException("Uploaded file does not match expected file hash. "
                                         f"[{file_info['sha256']} != {expected_sha256}]")
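
    # Illustrative sketch (comments only, not part of the original module): how a
    # service-server endpoint might push a client-submitted file through upload_file.
    # The path, classification string, TTL and SHA256 placeholder below are hypothetical.
    #
    #   with TaskingClient() as tasking:
    #       tasking.upload_file('/tmp/sample.bin', classification='TLP:W', ttl=15,
    #                           is_section_image=False,
    #                           expected_sha256='<sha256 reported by the client>')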
" f"[{file_info['sha256']} != {expected_sha256}]") # Service @elasticapm.capture_span(span_type='tasking_client') def register_service(self, service_data, log_prefix=""): keep_alive = True try: # Get heuristics list heuristics = service_data.pop('heuristics', None) # Patch update_channel, registry_type before Service registration object creation service_data['update_channel'] = service_data.get( 'update_channel', self.config.services.preferred_update_channel) service_data['docker_config']['registry_type'] = service_data['docker_config'] \ .get('registry_type', self.config.services.preferred_registry_type) service_data['privileged'] = service_data.get( 'privileged', self.config.services.prefer_service_privileged) for dep in service_data.get('dependencies', {}).values(): dep['container']['registry_type'] = dep.get( 'registry_type', self.config.services.preferred_registry_type) # Pop unused registration service_data for x in ['file_required', 'tool_version']: service_data.pop(x, None) # Create Service registration object service = Service(service_data) # Fix service version, we don't need to see the stable label service.version = service.version.replace('stable', '') # Save service if it doesn't already exist if not self.datastore.service.exists( f'{service.name}_{service.version}'): self.datastore.service.save( f'{service.name}_{service.version}', service) self.datastore.service.commit() self.log.info(f"{log_prefix}{service.name} registered") keep_alive = False # Save service delta if it doesn't already exist if not self.datastore.service_delta.exists(service.name): self.datastore.service_delta.save(service.name, {'version': service.version}) self.datastore.service_delta.commit() self.log.info(f"{log_prefix}{service.name} " f"version ({service.version}) registered") new_heuristics = [] if heuristics: plan = self.datastore.heuristic.get_bulk_plan() for index, heuristic in enumerate(heuristics): heuristic_id = f'#{index}' # Set heuristic id to it's position in the list for logging purposes try: # Append service name to heuristic ID heuristic[ 'heur_id'] = f"{service.name.upper()}.{str(heuristic['heur_id'])}" # Attack_id field is now a list, make it a list if we receive otherwise attack_id = heuristic.get('attack_id', None) if isinstance(attack_id, str): heuristic['attack_id'] = [attack_id] heuristic = Heuristic(heuristic) heuristic_id = heuristic.heur_id plan.add_upsert_operation(heuristic_id, heuristic) except Exception as e: msg = f"{service.name} has an invalid heuristic ({heuristic_id}): {str(e)}" self.log.exception(f"{log_prefix}{msg}") raise ValueError(msg) for item in self.datastore.heuristic.bulk(plan)['items']: if item['update']['result'] != "noop": new_heuristics.append(item['update']['_id']) self.log.info( f"{log_prefix}{service.name} " f"heuristic {item['update']['_id']}: {item['update']['result'].upper()}" ) self.datastore.heuristic.commit() service_config = self.datastore.get_service_with_delta( service.name, as_obj=False) # Notify components watching for service config changes self.event_sender.send(service.name, { 'operation': Operation.Added, 'name': service.name }) except ValueError as e: # Catch errors when building Service or Heuristic model(s) raise e return dict(keep_alive=keep_alive, new_heuristics=new_heuristics, service_config=service_config or dict()) # Task @elasticapm.capture_span(span_type='tasking_client') def get_task(self, client_id, service_name, service_version, service_tool_version, metric_factory, status_expiry=None, timeout=30): if status_expiry is None: 
        cache_found = False

        try:
            service_data = self.dispatch_client.service_data[service_name]
        except KeyError:
            raise ServiceMissingException("The service you're asking task for does not exist, try later", 404)

        # Set the service status to Idle since we will be waiting for a task
        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, status_expiry))

        # Getting a new task
        task = self.dispatch_client.request_work(client_id, service_name, service_version, timeout=timeout)

        if not task:
            # We've reached the timeout and no task found in service queue
            return None, False

        # We've got a task to process, consider us busy
        self.status_table.set(client_id, (service_name, ServiceStatus.Running, time.time() + service_data.timeout))
        metric_factory.increment('execute')

        result_key = Result.help_build_key(sha256=task.fileinfo.sha256,
                                           service_name=service_name,
                                           service_version=service_version,
                                           service_tool_version=service_tool_version,
                                           is_empty=False,
                                           task=task)

        # If we are allowed, try to see if the result has been cached
        if not task.ignore_cache and not service_data.disable_cache:
            # Checking for previous results for this key
            result = self.datastore.result.get_if_exists(result_key)
            if result:
                metric_factory.increment('cache_hit')
                if result.result.score:
                    metric_factory.increment('scored')
                else:
                    metric_factory.increment('not_scored')

                result.archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
                if task.ttl:
                    result.expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)

                self.dispatch_client.service_finished(task.sid, result_key, result)
                cache_found = True

            if not cache_found:
                # Checking for previous empty results for this key
                result = self.datastore.emptyresult.get_if_exists(f"{result_key}.e")
                if result:
                    metric_factory.increment('cache_hit')
                    metric_factory.increment('not_scored')
                    result = self.datastore.create_empty_result_from_key(result_key)
                    self.dispatch_client.service_finished(task.sid, f"{result_key}.e", result)
                    cache_found = True

            if not cache_found:
                metric_factory.increment('cache_miss')
        else:
            metric_factory.increment('cache_skipped')

        if not cache_found:
            # No luck with the cache, let's dispatch the task to a client
            return task.as_primitives(), False

        return None, True

    @elasticapm.capture_span(span_type='tasking_client')
    def task_finished(self, service_task, client_id, service_name, metric_factory):
        exec_time = service_task.get('exec_time')

        try:
            task = ServiceTask(service_task['task'])

            if 'result' in service_task:  # Task created a result
                missing_files = self._handle_task_result(exec_time, task, service_task['result'], client_id,
                                                         service_name, service_task['freshen'], metric_factory)
                if missing_files:
                    return dict(success=False, missing_files=missing_files)
                return dict(success=True)

            elif 'error' in service_task:  # Task created an error
                error = service_task['error']
                self._handle_task_error(exec_time, task, error, client_id, service_name, metric_factory)
                return dict(success=True)
            else:
                return None

        except ValueError as e:
            # Catch errors when building Task or Result model
            raise e

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_result(self, exec_time: int, task: ServiceTask, result: Dict[str, Any], client_id,
                            service_name, freshen: bool, metric_factory):

        def freshen_file(file_info_list, item):
            file_info = file_info_list.get(item['sha256'], None)
            if file_info is None or not self.filestore.exists(item['sha256']):
                return True
            else:
                file_info['archive_ts'] = archive_ts
                file_info['expiry_ts'] = expiry_ts
                file_info['classification'] = item['classification']
                self.datastore.save_or_freshen_file(item['sha256'], file_info, file_info['expiry_ts'],
                                                    file_info['classification'],
                                                    is_section_image=item.get('is_section_image', False))
            return False

        archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)
        else:
            expiry_ts = None

        # Check if all files are in the filestore
        if freshen:
            missing_files = []
            hashes = list(set([f['sha256']
                               for f in result['response']['extracted'] + result['response']['supplementary']]))
            file_infos = self.datastore.file.multiget(hashes, as_obj=False, error_on_missing=False)

            with elasticapm.capture_span(name="handle_task_result.freshen_files", span_type="tasking_client"):
                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    res = {
                        f['sha256']: executor.submit(freshen_file, file_infos, f)
                        for f in result['response']['extracted'] + result['response']['supplementary']
                    }
                for k, v in res.items():
                    if v.result():
                        missing_files.append(k)

            if missing_files:
                return missing_files

        # Add scores to the heuristics, if any section set a heuristic
        with elasticapm.capture_span(name="handle_task_result.process_heuristics", span_type="tasking_client"):
            total_score = 0
            for section in result['result']['sections']:
                zeroize_on_sig_safe = section.pop('zeroize_on_sig_safe', True)
                section['tags'] = flatten(section['tags'])
                if section.get('heuristic'):
                    heur_id = f"{service_name.upper()}.{str(section['heuristic']['heur_id'])}"
                    section['heuristic']['heur_id'] = heur_id
                    try:
                        section['heuristic'], new_tags = \
                            self.heuristic_handler.service_heuristic_to_result_heuristic(
                                section['heuristic'], self.heuristics, zeroize_on_sig_safe)
                        for tag in new_tags:
                            section['tags'].setdefault(tag[0], [])
                            if tag[1] not in section['tags'][tag[0]]:
                                section['tags'][tag[0]].append(tag[1])
                        total_score += section['heuristic']['score']
                    except InvalidHeuristicException:
                        section['heuristic'] = None

        # Update the total score of the result
        result['result']['score'] = total_score

        # Add timestamps for creation, archive and expiry
        result['created'] = now_as_iso()
        result['archive_ts'] = archive_ts
        result['expiry_ts'] = expiry_ts

        # Pop the temporary submission data
        temp_submission_data = result.pop('temp_submission_data', None)
        if temp_submission_data:
            old_submission_data = {row.name: row.value for row in task.temporary_submission_data}
            temp_submission_data = {
                k: v for k, v in temp_submission_data.items()
                if k not in old_submission_data or v != old_submission_data[k]
            }
            big_temp_data = {
                k: len(str(v)) for k, v in temp_submission_data.items()
                if len(str(v)) > self.config.submission.max_temp_data_length
            }
            if big_temp_data:
                big_data_sizes = [f"{k}={v}" for k, v in big_temp_data.items()]
                self.log.warning(
                    f"[{task.sid}] The following temporary submission keys were ignored because they are "
                    "bigger than the maximum data size allowed "
                    f"[{self.config.submission.max_temp_data_length}]: {' | '.join(big_data_sizes)}")
                temp_submission_data = {
                    k: v for k, v in temp_submission_data.items()
                    if k not in big_temp_data
                }

        # Process the tag values
        with elasticapm.capture_span(name="handle_task_result.process_tags", span_type="tasking_client"):
            for section in result['result']['sections']:
                # Perform tag safelisting
                tags, safelisted_tags = self.tag_safelister.get_validated_tag_map(section['tags'])
                section['tags'] = unflatten(tags)
                section['safelisted_tags'] = safelisted_tags

                section['tags'], dropped = construct_safe(Tagging, section.get('tags', {}))
                # Set section score to zero and lower total score if service is set to zeroize score
                # and all tags were safelisted
                if section.pop('zeroize_on_tag_safe', False) and \
                        section.get('heuristic') and \
                        len(tags) == 0 and \
                        len(safelisted_tags) != 0:
                    result['result']['score'] -= section['heuristic']['score']
                    section['heuristic']['score'] = 0

                if dropped:
                    self.log.warning(f"[{task.sid}] Invalid tag data from {service_name}: {dropped}")

        result = Result(result)
        result_key = result.build_key(service_tool_version=result.response.service_tool_version, task=task)
        self.dispatch_client.service_finished(task.sid, result_key, result, temp_submission_data)

        # Metrics
        if result.result.score > 0:
            metric_factory.increment('scored')
        else:
            metric_factory.increment('not_scored')

        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"successfully completed task {f' in {exec_time}ms' if exec_time else ''}")

        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, time.time() + 5))

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_error(self, exec_time: int, task: ServiceTask, error: Dict[str, Any], client_id,
                           service_name, metric_factory) -> None:
        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"failed to complete task {f' in {exec_time}ms' if exec_time else ''}")

        # Add timestamps for creation, archive and expiry
        error['created'] = now_as_iso()
        error['archive_ts'] = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            error['expiry_ts'] = now_as_iso(task.ttl * 24 * 60 * 60)

        error = Error(error)
        error_key = error.build_key(service_tool_version=error.response.service_tool_version, task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error)

        # Metrics
        if error.response.status == 'FAIL_RECOVERABLE':
            metric_factory.increment('fail_recoverable')
        else:
            metric_factory.increment('fail_nonrecoverable')

        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, time.time() + 5))
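
# Illustrative sketch (not part of the original module): the polling loop a service
# worker could run against TaskingClient. The client_id, service name, versions and
# metric_factory object here are hypothetical placeholders.
def _example_tasking_loop(tasking: TaskingClient, metric_factory):
    client_id = 'example-worker-1'
    while True:
        # get_task() returns (task, cache_hit); task is None when the queue stayed
        # empty for the whole timeout or the result was served from cache.
        task, cache_hit = tasking.get_task(client_id, 'EXAMPLE', '4.0.0', None,
                                           metric_factory, timeout=30)
        if task is None:
            continue
        # ... run the service on the task and build a result body ...
        result = {}  # placeholder for a real service result dictionary
        tasking.task_finished({'task': task, 'result': result, 'freshen': True, 'exec_time': 100},
                              client_id, 'EXAMPLE', metric_factory)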
class DispatchClient:
    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.timeout_watcher = WatcherClient(redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger("assemblyline.dispatching.client")
        # Use the resolved datastore so the defaults also work when no datastore is passed in
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH, host=redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))

    def _get_services(self):
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def dispatch_submission(self, submission: Submission, completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note: You probably actually want to use the SubmissionTool.

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(SubmissionTask(dict(
            submission=submission,
            completed_queue=completed_queue,
        )).as_primitives())

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of files
        each of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and number of files remaining per service
                 e.g. {"SERVICE_NAME": 1, ...}
        """
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        all_service_status = dispatch_hash.all_results()
        all_files = dispatch_hash.all_files()
        self.log.info(f"[{sid}] Listing outstanding services: {len(all_files)} files "
                      f"and {len(all_service_status)} entries found")

        output: Dict[str, int] = {}

        for file_hash in all_files:
            # The schedule might not be in the cache if the submission or file was just issued,
            # but it will be as soon as the file passes through the dispatcher
            schedule = dispatch_hash.schedules.get(file_hash)
            status_values = all_service_status.get(file_hash, {})

            # Go through the schedule stage by stage so we can react to drops:
            # either we have a result and we don't need to count the file (but might drop it),
            # or we don't have a result and we need to count that file
            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    status = status_values.get(service_name)
                    if status:
                        if status.drop:
                            schedule.clear()
                    else:
                        output[service_name] = output.get(service_name, 0) + 1

        return output

    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param service_version:
        :param worker_id:
        :param service_name: Which service needs work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately.
        :return: The task found, or None if the timeout expired without finding one.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            try:
                return self._request_work(worker_id, service_name, service_version,
                                          blocking=blocking, timeout=remaining)
            except RetryRequestWork:
                remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version, timeout, blocking) -> Optional[ServiceTask]:
        # For when we recursively retry on bad task dequeue-ing
        if int(timeout) <= 0:
            self.log.info(f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout))
        else:
            result = work_queue.pop(1)
            if result:
                result = result[0]

        if not result:
            self.log.info(f"{service_name}:{worker_id} no task returned: [empty message]")
            return None
        task = ServiceTask(result)

        # If someone is supposed to be working on this task right now, we won't be able to add it
        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found")

            process_table = DispatchHash(task.sid, self.redis)

            abandoned = process_table.dispatch_time(file_hash=task.fileinfo.sha256,
                                                    service=task.service_name) == 0
            finished = process_table.finished(file_hash=task.fileinfo.sha256,
                                              service=task.service_name) is not None

            # A service might be re-dispatched as it finishes, when that is the case it can be marked as
            # both finished and dispatched, if that is the case, drop the dispatch from the table
            if finished and not abandoned:
                process_table.drop_dispatch(file_hash=task.fileinfo.sha256, service=task.service_name)

            if abandoned or finished:
                self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] "
                              f"{service_name}:{worker_id} task already complete")
                self.running_tasks.pop(task.key())
                raise RetryRequestWork()

            # Check if this task has reached the retry limit
            attempt_record = ExpiringHash(f'dispatch-hash-attempts-{task.sid}', host=self.redis)
            total_attempts = attempt_record.increment(task.key())
            self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} "
                          f"task attempt {total_attempts}/3")
            if total_attempts > 3:
                self.log.warning(f"[{task.sid}/{task.fileinfo.sha256}] "
                                 f"{service_name}:{worker_id} marking task failed: TASK PREEMPTED")
                error = Error(dict(
                    archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                    created='NOW',
                    expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                    response=dict(
                        message='The number of retries has passed the limit.',
                        service_name=task.service_name,
                        service_version=service_version,
                        status='FAIL_NONRECOVERABLE',
                    ),
                    sha256=task.fileinfo.sha256,
                    type="TASK PRE-EMPTED",
                ))
                error_key = error.build_key(task=task)
                self.service_failed(task.sid, error_key, error)
                export_metrics_once(service_name, Metrics, dict(fail_nonrecoverable=1),
                                    host=worker_id, counter_type='service')
                raise RetryRequestWork()

            # Get the service information
            service_data = self.service_data[task.service_name]
            self.timeout_watcher.touch_task(timeout=int(service_data.timeout),
                                            key=f'{task.sid}-{task.key()}',
                                            worker=worker_id,
                                            task_key=task.key())
            return task
        raise RetryRequestWork()

    def _dispatching_error(self, task, process_table, error):
        error_key = error.build_key(task=task)
        if process_table.add_error(error_key):
            self.errors.save(error_key, error)
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)

    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(sid=sid, service_name=result.response.service_name, sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing successful results.")
            return
        task = ServiceTask(task)

        # Check if the service is a candidate for dynamic recursion prevention
        if not task.ignore_dynamic_recursion_prevention:
            service_info = self.service_data.get(result.response.service_name, None)
            if service_info and service_info.category == "Dynamic Analysis":
                # TODO: This should be done in lua because it can introduce a race condition in the future,
                #       but in the meantime it will remain this way while we confirm it works as expected
                submission = self.active_submissions.get(sid)
                submission['submission']['params']['services']['runtime_excluded'].append(
                    result.response.service_name)
                self.active_submissions.set(sid, submission)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty results will not be archived, therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            with Lock(f"lock-{result_key}", 5, self.redis):
                old = self.ds.result.get(result_key)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        result.expiry_ts = None
                self.ds.result.save(result_key, result)

        # Let the logs know we have received a result for this task
        if result.drop_file:
            self.log.debug(f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                           f"Result will be stored in {result_key} but processing will stop after this service.")
        else:
            self.log.debug(f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                           f"Result will be stored in {result_key}")
" f"Result will be stored in {result_key}") # Store the result object and mark the service finished in the global table process_table = DispatchHash(task.sid, self.redis) remaining, duplicate = process_table.finish( task.fileinfo.sha256, task.service_name, result_key, result.result.score, result.classification, result.drop_file) self.timeout_watcher.clear(f'{task.sid}-{task.key()}') if duplicate: self.log.warning( f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already " f"completed in the global processing table.") return # Push the result tags into redis new_tags = [] for section in result.result.sections: new_tags.extend(tag_dict_to_list(section.tags.as_primitives())) if new_tags: tag_set = ExpiringSet(get_tag_set_name( sid=task.sid, file_hash=task.fileinfo.sha256), host=self.redis) tag_set.add(*new_tags) # Update the temporary data table for this file temp_data_hash = ExpiringHash(get_temporary_submission_data_name( sid=task.sid, file_hash=task.fileinfo.sha256), host=self.redis) for key, value in (temporary_data or {}).items(): temp_data_hash.set(key, value) # Send the extracted files to the dispatcher depth_limit = self.config.submission.max_extraction_depth new_depth = task.depth + 1 if new_depth < depth_limit: # Prepare the temporary data from the parent to build the temporary data table for # these newly extract files parent_data = dict(temp_data_hash.items()) for extracted_data in result.response.extracted: if not process_table.add_file( extracted_data.sha256, task.max_files, parent_hash=task.fileinfo.sha256): if parent_data: child_hash_name = get_temporary_submission_data_name( task.sid, extracted_data.sha256) ExpiringHash(child_hash_name, host=self.redis).multi_set(parent_data) self._dispatching_error( task, process_table, Error({ 'archive_ts': result.archive_ts, 'expiry_ts': result.expiry_ts, 'response': { 'message': f"Too many files extracted for submission {task.sid} " f"{extracted_data.sha256} extracted by " f"{task.service_name} will be dropped", 'service_name': task.service_name, 'service_tool_version': result.response.service_tool_version, 'service_version': result.response.service_version, 'status': 'FAIL_NONRECOVERABLE' }, 'sha256': extracted_data.sha256, 'type': 'MAX FILES REACHED' })) continue file_data = self.files.get(extracted_data.sha256) self.file_queue.push( FileTask( dict(sid=task.sid, min_classification=task.min_classification.max( extracted_data.classification).value, file_info=dict( magic=file_data.magic, md5=file_data.md5, mime=file_data.mime, sha1=file_data.sha1, sha256=file_data.sha256, size=file_data.size, type=file_data.type, ), depth=new_depth, parent_hash=task.fileinfo.sha256, max_files=task.max_files)).as_primitives()) else: for extracted_data in result.response.extracted: self._dispatching_error( task, process_table, Error({ 'archive_ts': result.archive_ts, 'expiry_ts': result.expiry_ts, 'response': { 'message': f"{task.service_name} has extracted a file " f"{extracted_data.sha256} beyond the depth limits", 'service_name': result.response.service_name, 'service_tool_version': result.response.service_tool_version, 'service_version': result.response.service_version, 'status': 'FAIL_NONRECOVERABLE' }, 'sha256': extracted_data.sha256, 'type': 'MAX DEPTH REACHED' })) # If the global table said that this was the last outstanding service, # send a message to the dispatchers. 
        if remaining <= 0:
            self.file_queue.push(FileTask(dict(
                sid=task.sid,
                min_classification=task.min_classification.value,
                file_info=task.fileinfo,
                depth=task.depth,
                max_files=task.max_files
            )).as_primitives())

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w).push(msg)

    def service_failed(self, sid: str, error_key: str, error: Error):
        task_key = ServiceTask.make_key(sid=sid, service_name=error.response.service_name, sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(f"[{sid}/{error.sha256}] {task.service_name} failed with "
                       f"{error.response.status} error.")
        remaining = -1
        # Mark the attempt to process the file over in the dispatch table
        process_table = DispatchHash(task.sid, self.redis)
        if error.response.status == "FAIL_RECOVERABLE":
            # Because the error is recoverable, we will not save it nor notify the user
            process_table.fail_recoverable(task.fileinfo.sha256, task.service_name)
        else:
            # This is a NON_RECOVERABLE error, the error will be saved and transmitted to the user
            self.errors.save(error_key, error)
            remaining, _duplicate = process_table.fail_nonrecoverable(task.fileinfo.sha256,
                                                                      task.service_name, error_key)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')

        # Send a message to prompt the re-issue of the task if needed
        if remaining <= 0:
            self.file_queue.push(FileTask(dict(
                sid=task.sid,
                min_classification=task.min_classification,
                file_info=task.fileinfo,
                depth=task.depth,
                max_files=task.max_files
            )).as_primitives())

    def setup_watch_queue(self, sid):
        """
        This function takes a submission ID as a parameter and creates a unique queue where all service
        result keys for that given submission will be returned as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys
        through the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        # Create a unique queue
        queue_name = reply_queue_name(prefix="D", suffix="WQ")
        watch_queue = NamedQueue(queue_name, ttl=30)
        watch_queue.push(WatchQueueMessage({'status': 'START'}).as_primitives())

        # Add the newly created queue to the list of queues for the given submission
        self._get_watcher_list(sid).add(queue_name)

        # Push all current keys to the newly created queue (the queue should have a TTL of about 30 sec to 1 minute)
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        if dispatch_hash.dispatch_count() == 0 and dispatch_hash.finished_count() == 0:
            # This table is empty? do we have this submission at all?
            submission = self.ds.submission.get(sid)
            if not submission or submission.state == 'completed':
                watch_queue.push(WatchQueueMessage({"status": "STOP"}).as_primitives())
            else:
                # We do have a submission, remind the dispatcher to work on it
                self.submission_queue.push({'sid': sid})
        else:
            all_service_status = dispatch_hash.all_results()
            for status_values in all_service_status.values():
                for status in status_values.values():
                    if status.is_error:
                        watch_queue.push(WatchQueueMessage({
                            "status": "FAIL",
                            "cache_key": status.key
                        }).as_primitives())
                    else:
                        watch_queue.push(WatchQueueMessage({
                            "status": "OK",
                            "cache_key": status.key
                        }).as_primitives())

        return queue_name

    def _get_watcher_list(self, sid):
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)
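
# Illustrative sketch (not part of the original module): pushing an already-saved
# submission into dispatching and draining its watch queue. Assumes the Submission
# and its files were saved to the datastore/filestore first, as dispatch_submission
# requires; the blocking semantics of NamedQueue.pop() are not relied on here.
def _example_dispatch_and_watch(dispatcher: DispatchClient, submission: Submission):
    # Hand the submission to the dispatcher.
    dispatcher.dispatch_submission(submission)

    # Every result/error key for this submission will be mirrored onto this queue.
    queue_name = dispatcher.setup_watch_queue(submission.sid)
    watch_queue = NamedQueue(queue_name)

    # Drain whatever keys are currently available; pop() returns None when empty.
    message = watch_queue.pop()
    while message is not None and message.get('status') != 'STOP':
        print(message['status'], message.get('cache_key'))
        message = watch_queue.pop()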
class RandomService(ServerBase):
    """Replaces everything past the dispatcher, including the service API.

    In the future the real service API will probably be included in this test.
    """

    def __init__(self, datastore=None, filestore=None):
        super().__init__('assemblyline.randomservice')
        self.config = forge.get_config()
        self.datastore = datastore or forge.get_datastore()
        self.filestore = filestore or forge.get_filestore()
        self.client_id = get_random_id()
        self.service_state_hash = ExpiringHash(SERVICE_STATE_HASH, ttl=30 * 60)

        self.counters = {
            n: MetricsFactory('service', Metrics, name=n, config=self.config)
            for n in self.datastore.service_delta.keys()
        }
        self.queues = [forge.get_service_queue(name) for name in self.datastore.service_delta.keys()]
        self.dispatch_client = DispatchClient(self.datastore)
        self.service_info = CachedObject(self.datastore.list_all_services, kwargs={'as_obj': False})

    def run(self):
        self.log.info("Random service result generator ready!")
        self.log.info("Monitoring queues:")
        for q in self.queues:
            self.log.info(f"\t{q.name}")

        self.log.info("Waiting for messages...")
        while self.running:
            # Reset Idle flags
            for s in self.service_info:
                if s['enabled']:
                    self.service_state_hash.set(f"{self.client_id}_{s['name']}",
                                                (s['name'], ServiceStatus.Idle, time.time() + 30 + 5))

            message = select(*self.queues, timeout=1)
            if not message:
                continue

            archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            if self.config.submission.dtl:
                expiry_ts = now_as_iso(self.config.submission.dtl * 24 * 60 * 60)
            else:
                expiry_ts = None
            queue, msg = message
            task = ServiceTask(msg)

            if not self.dispatch_client.running_tasks.add(task.key(), task.as_primitives()):
                continue

            # Set service busy flag
            self.service_state_hash.set(f"{self.client_id}_{task.service_name}",
                                        (task.service_name, ServiceStatus.Running, time.time() + 30 + 5))

            # METRICS
            self.counters[task.service_name].increment('execute')
            # METRICS (not caching here so always miss)
            self.counters[task.service_name].increment('cache_miss')

            self.log.info(f"\tQueue {queue} received a new task for sid {task.sid}.")
            action = random.randint(1, 10)
            if action >= 2:
                if action > 8:
                    result = random_minimal_obj(Result)
                else:
                    result = random_model_obj(Result)
                result.sha256 = task.fileinfo.sha256
                result.response.service_name = task.service_name
                result.archive_ts = archive_ts
                result.expiry_ts = expiry_ts
                result.response.extracted = result.response.extracted[task.depth + 2:]
                result.response.supplementary = result.response.supplementary[task.depth + 2:]
                result_key = Result.help_build_key(sha256=task.fileinfo.sha256,
                                                   service_name=task.service_name,
                                                   service_version='0',
                                                   is_empty=result.is_empty())

                self.log.info(f"\t\tA result was generated for this task: {result_key}")

                new_files = result.response.extracted + result.response.supplementary
                for f in new_files:
                    if not self.datastore.file.get(f.sha256):
                        random_file = random_model_obj(File)
                        random_file.archive_ts = archive_ts
                        random_file.expiry_ts = expiry_ts
                        random_file.sha256 = f.sha256
                        self.datastore.file.save(f.sha256, random_file)
                    if not self.filestore.exists(f.sha256):
                        self.filestore.put(f.sha256, f.sha256)

                time.sleep(random.randint(0, 2))

                self.dispatch_client.service_finished(task.sid, result_key, result)

                # METRICS
                if result.result.score > 0:
                    self.counters[task.service_name].increment('scored')
                else:
                    self.counters[task.service_name].increment('not_scored')

            else:
                error = random_model_obj(Error)
                error.archive_ts = archive_ts
                error.expiry_ts = expiry_ts
                error.sha256 = task.fileinfo.sha256
                error.response.service_name = task.service_name
                error.type = random.choice(["EXCEPTION", "SERVICE DOWN", "SERVICE BUSY"])

                error_key = error.build_key('0')

                self.log.info(f"\t\tA {error.response.status}:{error.type} "
                              f"error was generated for this task: {error_key}")

                self.dispatch_client.service_failed(task.sid, error_key, error)

                # METRICS
                if error.response.status == "FAIL_RECOVERABLE":
                    self.counters[task.service_name].increment('fail_recoverable')
                else:
                    self.counters[task.service_name].increment('fail_nonrecoverable')
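
# Illustrative sketch (not part of the original module): driving RandomService in a
# local test. ServerBase is assumed to provide the start()/stop()/join() lifecycle of
# a thread-based server; the ten-second run window is arbitrary.
def _example_random_service_run():
    service = RandomService()
    service.start()    # begin consuming every registered service queue
    time.sleep(10)     # let it fabricate a few results and errors
    service.stop()     # ask the ServerBase loop to wind down
    service.join()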