def test_dispatch_extracted(clean_redis, clean_datastore):
    redis = clean_redis
    ds = clean_datastore

    # def service_queue(name): return get_service_queue(name, redis)

    # Setup the fake datastore
    file_hash = get_random_hash(64)
    second_file_hash = get_random_hash(64)

    for fh in [file_hash, second_file_hash]:
        obj = random_model_obj(models.file.File)
        obj.sha256 = fh
        ds.file.save(fh, obj)

    # Inject the fake submission
    submission = random_model_obj(models.submission.Submission)
    submission.files = [dict(name='./file', sha256=file_hash)]
    sid = submission.sid = 'first-submission'

    disp = Dispatcher(ds, redis, redis)
    disp.running = ToggleTrue()
    client = DispatchClient(ds, redis, redis)
    client.dispatcher_data_age = time.time()
    client.dispatcher_data.append(disp.instance_id)

    # Launch the submission
    client.dispatch_submission(submission)
    disp.pull_submissions()
    disp.service_worker(disp.process_queue_index(sid))

    # Finish one service extracting a file
    job = client.request_work('0', 'extract', '0')
    assert job.fileinfo.sha256 == file_hash
    assert job.filename == './file'
    new_result: Result = random_minimal_obj(Result)
    new_result.sha256 = file_hash
    new_result.response.service_name = 'extract'
    new_result.response.extracted = [dict(sha256=second_file_hash, name='second-*',
                                          description='abc', classification='U')]
    client.service_finished(sid, 'extracted-done', new_result)

    # Process the result; the extracted file should be dispatched in turn
    disp.pull_service_results()
    disp.service_worker(disp.process_queue_index(sid))
    disp.service_worker(disp.process_queue_index(sid))

    job = client.request_work('0', 'extract', '0')
    assert job.fileinfo.sha256 == second_file_hash
    assert job.filename == 'second-*'

class Plumber(CoreBase):
    def __init__(self, logger=None, shutdown_timeout: float = None, config=None,
                 redis=None, redis_persist=None, datastore=None, delay=60):
        super().__init__('plumber', logger, shutdown_timeout, config=config, redis=redis,
                         redis_persist=redis_persist, datastore=datastore)
        self.delay = float(delay)
        self.dispatch_client = DispatchClient(datastore=self.datastore, redis=self.redis,
                                              redis_persist=self.redis_persist, logger=self.log)

    def try_run(self):
        # Get an initial list of all the service queues; removeprefix (rather than
        # lstrip, which strips a character set) safely drops the literal prefix
        service_queues = {queue.decode('utf-8').removeprefix('service-queue-'): None
                          for queue in self.redis.keys(service_queue_name('*'))}

        while self.running:
            self.heartbeat()

            # Reset the status of the service queues
            service_queues = {service_name: None for service_name in service_queues}

            # Update the service queue status based on the current list of services
            for service in self.datastore.list_all_services(full=True):
                service_queues[service.name] = service

            for service_name, service in service_queues.items():
                if not service or not service.enabled or self.get_service_stage(service_name) != ServiceStage.Running:
                    while True:
                        task = self.dispatch_client.request_work(None, service_name=service_name,
                                                                 service_version='0', blocking=False)
                        if task is None:
                            break

                        error = Error(dict(
                            archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                            created='NOW',
                            expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                            response=dict(
                                message='The service was disabled while processing this task.',
                                service_name=task.service_name,
                                service_version='0',
                                status='FAIL_NONRECOVERABLE',
                            ),
                            sha256=task.fileinfo.sha256,
                            type="TASK PRE-EMPTED",
                        ))

                        error_key = error.build_key(task=task)
                        self.dispatch_client.service_failed(task.sid, error_key, error)

            # Wait a while before checking the status of all services again
            time.sleep(self.delay)

class TaskingClient:
    """A helper class to simplify tasking for privileged services and service-server.

    This tool helps take care of interactions between the filestore,
    datastore, dispatcher, and any sources of files to be processed.
    """

    def __init__(self, datastore: AssemblylineDatastore = None, filestore: FileStore = None,
                 config=None, redis=None, redis_persist=None, identify=None):
        self.log = logging.getLogger('assemblyline.tasking_client')
        self.config = config or forge.CachedObject(forge.get_config)
        self.datastore = datastore or forge.get_datastore(self.config)
        self.dispatch_client = DispatchClient(self.datastore, redis=redis, redis_persist=redis_persist)
        self.event_sender = EventSender('changes.services', redis)
        self.filestore = filestore or forge.get_filestore(self.config)
        self.heuristic_handler = HeuristicHandler(self.datastore)
        self.heuristics = {h.heur_id: h for h in self.datastore.list_all_heuristics()}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60 * 30, host=redis)
        self.tag_safelister = forge.CachedObject(forge.get_tag_safelister, kwargs=dict(
            log=self.log, config=config, datastore=self.datastore), refresh=300)
        if identify:
            self.cleanup = False
        else:
            self.cleanup = True
        self.identify = identify or forge.get_identify(config=self.config, datastore=self.datastore, use_cache=True)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.stop()

    def stop(self):
        if self.cleanup:
            self.identify.stop()

    @elasticapm.capture_span(span_type='tasking_client')
    def upload_file(self, file_path, classification, ttl, is_section_image, expected_sha256=None):
        # Identify the file info of the uploaded file
        file_info = self.identify.fileinfo(file_path)

        # Validate SHA256 of the uploaded file
        if expected_sha256 is None or expected_sha256 == file_info['sha256']:
            file_info['archive_ts'] = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            file_info['classification'] = classification
            if ttl:
                file_info['expiry_ts'] = now_as_iso(ttl * 24 * 60 * 60)
            else:
                file_info['expiry_ts'] = None

            # Update the datastore with the uploaded file
            self.datastore.save_or_freshen_file(file_info['sha256'], file_info, file_info['expiry_ts'],
                                                file_info['classification'], is_section_image=is_section_image)

            # Upload file to the filestore (upload already checks if the file exists)
            self.filestore.upload(file_path, file_info['sha256'])
        else:
            raise TaskingClientException(
                "Uploaded file does not match expected file hash. "
" f"[{file_info['sha256']} != {expected_sha256}]") # Service @elasticapm.capture_span(span_type='tasking_client') def register_service(self, service_data, log_prefix=""): keep_alive = True try: # Get heuristics list heuristics = service_data.pop('heuristics', None) # Patch update_channel, registry_type before Service registration object creation service_data['update_channel'] = service_data.get( 'update_channel', self.config.services.preferred_update_channel) service_data['docker_config']['registry_type'] = service_data['docker_config'] \ .get('registry_type', self.config.services.preferred_registry_type) service_data['privileged'] = service_data.get( 'privileged', self.config.services.prefer_service_privileged) for dep in service_data.get('dependencies', {}).values(): dep['container']['registry_type'] = dep.get( 'registry_type', self.config.services.preferred_registry_type) # Pop unused registration service_data for x in ['file_required', 'tool_version']: service_data.pop(x, None) # Create Service registration object service = Service(service_data) # Fix service version, we don't need to see the stable label service.version = service.version.replace('stable', '') # Save service if it doesn't already exist if not self.datastore.service.exists( f'{service.name}_{service.version}'): self.datastore.service.save( f'{service.name}_{service.version}', service) self.datastore.service.commit() self.log.info(f"{log_prefix}{service.name} registered") keep_alive = False # Save service delta if it doesn't already exist if not self.datastore.service_delta.exists(service.name): self.datastore.service_delta.save(service.name, {'version': service.version}) self.datastore.service_delta.commit() self.log.info(f"{log_prefix}{service.name} " f"version ({service.version}) registered") new_heuristics = [] if heuristics: plan = self.datastore.heuristic.get_bulk_plan() for index, heuristic in enumerate(heuristics): heuristic_id = f'#{index}' # Set heuristic id to it's position in the list for logging purposes try: # Append service name to heuristic ID heuristic[ 'heur_id'] = f"{service.name.upper()}.{str(heuristic['heur_id'])}" # Attack_id field is now a list, make it a list if we receive otherwise attack_id = heuristic.get('attack_id', None) if isinstance(attack_id, str): heuristic['attack_id'] = [attack_id] heuristic = Heuristic(heuristic) heuristic_id = heuristic.heur_id plan.add_upsert_operation(heuristic_id, heuristic) except Exception as e: msg = f"{service.name} has an invalid heuristic ({heuristic_id}): {str(e)}" self.log.exception(f"{log_prefix}{msg}") raise ValueError(msg) for item in self.datastore.heuristic.bulk(plan)['items']: if item['update']['result'] != "noop": new_heuristics.append(item['update']['_id']) self.log.info( f"{log_prefix}{service.name} " f"heuristic {item['update']['_id']}: {item['update']['result'].upper()}" ) self.datastore.heuristic.commit() service_config = self.datastore.get_service_with_delta( service.name, as_obj=False) # Notify components watching for service config changes self.event_sender.send(service.name, { 'operation': Operation.Added, 'name': service.name }) except ValueError as e: # Catch errors when building Service or Heuristic model(s) raise e return dict(keep_alive=keep_alive, new_heuristics=new_heuristics, service_config=service_config or dict()) # Task @elasticapm.capture_span(span_type='tasking_client') def get_task(self, client_id, service_name, service_version, service_tool_version, metric_factory, status_expiry=None, timeout=30): if status_expiry is None: 
            status_expiry = time.time() + timeout

        cache_found = False

        try:
            service_data = self.dispatch_client.service_data[service_name]
        except KeyError:
            raise ServiceMissingException("The service you're requesting tasks for does not exist, try again later", 404)

        # Set the service status to Idle since we will be waiting for a task
        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, status_expiry))

        # Get a new task
        task = self.dispatch_client.request_work(client_id, service_name, service_version, timeout=timeout)
        if not task:
            # We've reached the timeout and no task was found in the service queue
            return None, False

        # We've got a task to process, consider us busy
        self.status_table.set(client_id, (service_name, ServiceStatus.Running, time.time() + service_data.timeout))
        metric_factory.increment('execute')

        result_key = Result.help_build_key(sha256=task.fileinfo.sha256,
                                           service_name=service_name,
                                           service_version=service_version,
                                           service_tool_version=service_tool_version,
                                           is_empty=False,
                                           task=task)

        # If we are allowed, check whether the result has been cached
        if not task.ignore_cache and not service_data.disable_cache:
            # Check for previous results for this key
            result = self.datastore.result.get_if_exists(result_key)
            if result:
                metric_factory.increment('cache_hit')
                if result.result.score:
                    metric_factory.increment('scored')
                else:
                    metric_factory.increment('not_scored')

                result.archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
                if task.ttl:
                    result.expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)

                self.dispatch_client.service_finished(task.sid, result_key, result)
                cache_found = True

            if not cache_found:
                # Check for previous empty results for this key
                result = self.datastore.emptyresult.get_if_exists(f"{result_key}.e")
                if result:
                    metric_factory.increment('cache_hit')
                    metric_factory.increment('not_scored')
                    result = self.datastore.create_empty_result_from_key(result_key)
                    self.dispatch_client.service_finished(task.sid, f"{result_key}.e", result)
                    cache_found = True

            if not cache_found:
                metric_factory.increment('cache_miss')
        else:
            metric_factory.increment('cache_skipped')

        if not cache_found:
            # No luck with the cache, let's dispatch the task to a client
            return task.as_primitives(), False

        return None, True

    @elasticapm.capture_span(span_type='tasking_client')
    def task_finished(self, service_task, client_id, service_name, metric_factory):
        exec_time = service_task.get('exec_time')

        try:
            task = ServiceTask(service_task['task'])

            if 'result' in service_task:
                # The task created a result
                missing_files = self._handle_task_result(exec_time, task, service_task['result'],
                                                         client_id, service_name, service_task['freshen'],
                                                         metric_factory)
                if missing_files:
                    return dict(success=False, missing_files=missing_files)
                return dict(success=True)
            elif 'error' in service_task:
                # The task created an error
                error = service_task['error']
                self._handle_task_error(exec_time, task, error, client_id, service_name, metric_factory)
                return dict(success=True)
            else:
                return None
        except ValueError as e:
            # Catch errors when building the Task or Result model
            raise e

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_result(self, exec_time: int, task: ServiceTask, result: Dict[str, Any],
                            client_id, service_name, freshen: bool, metric_factory):

        def freshen_file(file_info_list, item):
            file_info = file_info_list.get(item['sha256'], None)
            if file_info is None or not self.filestore.exists(item['sha256']):
                return True
            else:
                file_info['archive_ts'] = archive_ts
                file_info['expiry_ts'] = expiry_ts
                file_info['classification'] = item['classification']
                self.datastore.save_or_freshen_file(item['sha256'], file_info,
                                                    file_info['expiry_ts'],
                                                    file_info['classification'],
                                                    is_section_image=item.get('is_section_image', False))
            return False

        archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)
        else:
            expiry_ts = None

        # Check if all files are in the filestore
        if freshen:
            missing_files = []
            hashes = list(set([f['sha256']
                               for f in result['response']['extracted'] + result['response']['supplementary']]))
            file_infos = self.datastore.file.multiget(hashes, as_obj=False, error_on_missing=False)

            with elasticapm.capture_span(name="handle_task_result.freshen_files", span_type="tasking_client"):
                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    res = {
                        f['sha256']: executor.submit(freshen_file, file_infos, f)
                        for f in result['response']['extracted'] + result['response']['supplementary']}
                for k, v in res.items():
                    if v.result():
                        missing_files.append(k)

            if missing_files:
                return missing_files

        # Add scores to the heuristics, if any section set a heuristic
        with elasticapm.capture_span(name="handle_task_result.process_heuristics", span_type="tasking_client"):
            total_score = 0
            for section in result['result']['sections']:
                zeroize_on_sig_safe = section.pop('zeroize_on_sig_safe', True)
                section['tags'] = flatten(section['tags'])
                if section.get('heuristic'):
                    heur_id = f"{service_name.upper()}.{str(section['heuristic']['heur_id'])}"
                    section['heuristic']['heur_id'] = heur_id
                    try:
                        section['heuristic'], new_tags = self.heuristic_handler.service_heuristic_to_result_heuristic(
                            section['heuristic'], self.heuristics, zeroize_on_sig_safe)
                        for tag in new_tags:
                            section['tags'].setdefault(tag[0], [])
                            if tag[1] not in section['tags'][tag[0]]:
                                section['tags'][tag[0]].append(tag[1])
                        total_score += section['heuristic']['score']
                    except InvalidHeuristicException:
                        section['heuristic'] = None

        # Update the total score of the result
        result['result']['score'] = total_score

        # Add timestamps for creation, archive and expiry
        result['created'] = now_as_iso()
        result['archive_ts'] = archive_ts
        result['expiry_ts'] = expiry_ts

        # Pop the temporary submission data
        temp_submission_data = result.pop('temp_submission_data', None)
        if temp_submission_data:
            old_submission_data = {row.name: row.value for row in task.temporary_submission_data}
            temp_submission_data = {
                k: v for k, v in temp_submission_data.items()
                if k not in old_submission_data or v != old_submission_data[k]}
            big_temp_data = {k: len(str(v)) for k, v in temp_submission_data.items()
                             if len(str(v)) > self.config.submission.max_temp_data_length}
            if big_temp_data:
                big_data_sizes = [f"{k}={v}" for k, v in big_temp_data.items()]
                self.log.warning(f"[{task.sid}] The following temporary submission keys were ignored because they are "
                                 "bigger than the maximum data size allowed "
                                 f"[{self.config.submission.max_temp_data_length}]: {' | '.join(big_data_sizes)}")
                temp_submission_data = {k: v for k, v in temp_submission_data.items()
                                        if k not in big_temp_data}

        # Process the tag values
        with elasticapm.capture_span(name="handle_task_result.process_tags", span_type="tasking_client"):
            for section in result['result']['sections']:
                # Perform tag safelisting
                tags, safelisted_tags = self.tag_safelister.get_validated_tag_map(section['tags'])
                section['tags'] = unflatten(tags)
                section['safelisted_tags'] = safelisted_tags

                section['tags'], dropped = construct_safe(Tagging, section.get('tags', {}))
                # Set the section score to zero and lower the total score if the service
                # is set to zeroize scores and all tags were safelisted
                if section.pop('zeroize_on_tag_safe', False) and \
                        section.get('heuristic') and \
                        len(tags) == 0 and \
                        len(safelisted_tags) != 0:
                    result['result']['score'] -= section['heuristic']['score']
                    section['heuristic']['score'] = 0

                if dropped:
                    self.log.warning(f"[{task.sid}] Invalid tag data from {service_name}: {dropped}")

        result = Result(result)
        result_key = result.build_key(service_tool_version=result.response.service_tool_version, task=task)
        self.dispatch_client.service_finished(task.sid, result_key, result, temp_submission_data)

        # Metrics
        if result.result.score > 0:
            metric_factory.increment('scored')
        else:
            metric_factory.increment('not_scored')

        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"successfully completed task{f' in {exec_time}ms' if exec_time else ''}")

        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, time.time() + 5))

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_error(self, exec_time: int, task: ServiceTask, error: Dict[str, Any],
                           client_id, service_name, metric_factory) -> None:
        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"failed to complete task{f' in {exec_time}ms' if exec_time else ''}")

        # Add timestamps for creation, archive and expiry
        error['created'] = now_as_iso()
        error['archive_ts'] = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            error['expiry_ts'] = now_as_iso(task.ttl * 24 * 60 * 60)

        error = Error(error)
        error_key = error.build_key(service_tool_version=error.response.service_tool_version, task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error)

        # Metrics
        if error.response.status == 'FAIL_RECOVERABLE':
            metric_factory.increment('fail_recoverable')
        else:
            metric_factory.increment('fail_nonrecoverable')

        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, time.time() + 5))
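
# A minimal usage sketch for TaskingClient, assuming a configured Assemblyline
# environment so the forge defaults above resolve. The Counter stub is
# hypothetical; the real metric factory only needs an increment(name) method here.
if __name__ == '__main__':
    class Counter:
        def increment(self, name):
            print('metric:', name)

    with TaskingClient() as tasking:
        # get_task returns (task_primitives, cache_hit); (None, False) means the
        # timeout was reached, (None, True) means a cached result was served.
        task, cache_hit = tasking.get_task(client_id='worker-1', service_name='extract',
                                           service_version='4.0.0', service_tool_version=None,
                                           metric_factory=Counter(), timeout=5)
        if task:
            print('got task for', task['fileinfo']['sha256'])
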
def test_simple(clean_redis, clean_datastore):
    ds = clean_datastore
    redis = clean_redis

    def service_queue(name):
        return get_service_queue(name, redis)

    file = random_model_obj(File)
    file_hash = file.sha256
    file.type = 'unknown'
    ds.file.save(file_hash, file)

    sub: Submission = random_model_obj(models.submission.Submission)
    sub.sid = sid = 'first-submission'
    sub.params.ignore_cache = False
    sub.params.max_extracted = 5
    sub.params.classification = get_classification().UNRESTRICTED
    sub.params.initial_data = json.dumps({'cats': 'big'})
    sub.files = [dict(sha256=file_hash, name='file')]

    disp = Dispatcher(ds, redis, redis)
    disp.running = ToggleTrue()
    client = DispatchClient(ds, redis, redis)
    client.dispatcher_data_age = time.time()
    client.dispatcher_data.append(disp.instance_id)

    # Submit a problem, and check that it gets added to the dispatch hash
    # and the right service queues
    logger.info('==== first dispatch')
    # task = SubmissionTask(sub.as_primitives(), 'some-completion-queue')
    client.dispatch_submission(sub)
    disp.pull_submissions()
    disp.service_worker(disp.process_queue_index(sid))
    task = disp.tasks.get(sid)

    assert task.queue_keys[(file_hash, 'extract')] is not None
    assert task.queue_keys[(file_hash, 'wrench')] is not None
    assert service_queue('extract').length() == 1
    assert service_queue('wrench').length() == 1

    # Making the same call again will queue it up again
    logger.info('==== second dispatch')
    disp.dispatch_file(task, file_hash)

    assert task.queue_keys[(file_hash, 'extract')] is not None
    assert task.queue_keys[(file_hash, 'wrench')] is not None
    assert service_queue('extract').length() == 1  # the queue doesn't pile up
    assert service_queue('wrench').length() == 1

    logger.info('==== third dispatch')
    job = client.request_work('0', 'extract', '0')
    assert job.temporary_submission_data == [{'name': 'cats', 'value': 'big'}]
    client.service_failed(sid, 'abc123', make_error(file_hash, 'extract'))
    # Deliberately do this in the wrong order to make sure that works
    disp.pull_service_results()
    disp.service_worker(disp.process_queue_index(sid))

    assert task.queue_keys[(file_hash, 'extract')] is not None
    assert task.queue_keys[(file_hash, 'wrench')] is not None
    assert service_queue('extract').length() == 1

    # Mark extract as finished, wrench as failed
    logger.info('==== fourth dispatch')
    client.request_work('0', 'extract', '0')
    client.request_work('0', 'wrench', '0')
    client.service_finished(sid, 'extract-result', make_result(file_hash, 'extract'))
    client.service_failed(sid, 'wrench-error', make_error(file_hash, 'wrench', False))
    for _ in range(2):
        disp.pull_service_results()
        disp.service_worker(disp.process_queue_index(sid))

    assert wait_error(task, file_hash, 'wrench')
    assert wait_result(task, file_hash, 'extract')
    assert service_queue('av-a').length() == 1
    assert service_queue('av-b').length() == 1
    assert service_queue('frankenstrings').length() == 1

    # Have the AVs fail, frankenstrings finishes
    logger.info('==== fifth dispatch')
    client.request_work('0', 'av-a', '0')
    client.request_work('0', 'av-b', '0')
    client.request_work('0', 'frankenstrings', '0')
    client.service_failed(sid, 'av-a-error', make_error(file_hash, 'av-a', False))
    client.service_failed(sid, 'av-b-error', make_error(file_hash, 'av-b', False))
    client.service_finished(sid, 'f-result', make_result(file_hash, 'frankenstrings'))
    for _ in range(3):
        disp.pull_service_results()
        disp.service_worker(disp.process_queue_index(sid))

    assert wait_result(task, file_hash, 'frankenstrings')
    assert wait_error(task, file_hash, 'av-a')
    assert wait_error(task, file_hash, 'av-b')
    assert service_queue('xerox').length() == 1

    # Finish the xerox service and check if the submission completion got checked
    logger.info('==== sixth dispatch')
    client.request_work('0', 'xerox', '0')
    client.service_finished(sid, 'xerox-result-key', make_result(file_hash, 'xerox'))
    disp.pull_service_results()
    disp.service_worker(disp.process_queue_index(sid))
    disp.save_submission()

    assert wait_result(task, file_hash, 'xerox')
    assert disp.tasks.get(sid) is None
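
# The tests above lean on make_result/make_error fixtures from this test module.
# A minimal sketch of what they might look like, assuming the random_minimal_obj
# helper and the Result/Error models; the exact fields the real fixtures set
# may differ.
def make_result(file_hash, service):
    result: Result = random_minimal_obj(Result)
    result.sha256 = file_hash
    result.response.service_name = service
    return result


def make_error(file_hash, service, recoverable=True):
    error: Error = random_minimal_obj(Error)
    error.sha256 = file_hash
    error.response.service_name = service
    error.response.status = 'FAIL_RECOVERABLE' if recoverable else 'FAIL_NONRECOVERABLE'
    return error
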
class Plumber(CoreBase):
    def __init__(self, logger=None, shutdown_timeout: float = None, config=None,
                 redis=None, redis_persist=None, datastore=None, delay=60):
        super().__init__('plumber', logger, shutdown_timeout, config=config, redis=redis,
                         redis_persist=redis_persist, datastore=datastore)
        self.delay = float(delay)
        self.dispatch_client = DispatchClient(datastore=self.datastore, redis=self.redis,
                                              redis_persist=self.redis_persist, logger=self.log)
        self.flush_threads: dict[str, threading.Thread] = {}
        self.stop_signals: dict[str, threading.Event] = {}
        self.service_limit: dict[str, int] = {}

    def stop(self):
        for sig in self.stop_signals.values():
            sig.set()
        super().stop()

    def try_run(self):
        # Get an initial list of all the service queues; removeprefix (rather than
        # lstrip, which strips a character set) safely drops the literal prefix
        service_queues: dict[str, Optional[Service]]
        service_queues = {queue.decode('utf-8').removeprefix('service-queue-'): None
                          for queue in self.redis.keys(service_queue_name('*'))}

        while self.running:
            # Reset the status of the service queues
            service_queues = {service_name: None for service_name in service_queues}

            # Update the service queue status based on the current list of services
            for service in self.datastore.list_all_services(full=True):
                service_queues[service.name] = service

            for service_name, service in service_queues.items():
                # For disabled or otherwise unavailable services, purge the queue
                current_stage = self.get_service_stage(service_name, ServiceStage.Running)
                if not service or not service.enabled or current_stage != ServiceStage.Running:
                    while True:
                        task = self.dispatch_client.request_work('plumber', service_name=service_name,
                                                                 service_version='0', blocking=False)
                        if task is None:
                            break

                        error = Error(dict(
                            archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                            created='NOW',
                            expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                            response=dict(
                                message='The service was disabled while processing this task.',
                                service_name=task.service_name,
                                service_version='0',
                                status='FAIL_NONRECOVERABLE',
                            ),
                            sha256=task.fileinfo.sha256,
                            type="TASK PRE-EMPTED",
                        ))

                        error_key = error.build_key(task=task)
                        self.dispatch_client.service_failed(task.sid, error_key, error)
                        self.heartbeat()

                # For services that are enabled but limited, run a watcher thread
                if not service or not service.enabled or service.max_queue_length == 0:
                    if service_name in self.stop_signals:
                        self.stop_signals[service_name].set()
                        self.service_limit.pop(service_name)
                        self.flush_threads.pop(service_name)
                elif service and service.enabled and service.max_queue_length > 0:
                    self.service_limit[service_name] = service.max_queue_length
                    thread = self.flush_threads.get(service_name)
                    if not thread or not thread.is_alive():
                        self.stop_signals[service_name] = threading.Event()
                        thread = threading.Thread(target=self.watch_service,
                                                  args=[service_name], daemon=True)
                        self.flush_threads[service_name] = thread
                        thread.start()

            # Wait a while before checking the status of all services again
            self.sleep_with_heartbeat(self.delay)

    def watch_service(self, service_name):
        service_queue = get_service_queue(service_name, self.redis)
        while self.running and not self.stop_signals[service_name].is_set():
            while service_queue.length() > self.service_limit[service_name]:
                task = self.dispatch_client.request_work('plumber', service_name=service_name,
                                                         service_version='0', blocking=False,
                                                         low_priority=True)
                if task is None:
                    break

                error = Error(dict(
                    archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                    created='NOW',
                    expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                    response=dict(
message="Task canceled due to execesive queuing.", service_name=task.service_name, service_version='0', status='FAIL_NONRECOVERABLE', ), sha256=task.fileinfo.sha256, type="TASK PRE-EMPTED", )) error_key = error.build_key(task=task) self.dispatch_client.service_failed(task.sid, error_key, error) self.sleep(2)
class MockService(ServerBase):
    """Replaces everything past the dispatcher.

    This includes the service API; in the future that should probably be
    included in this test.
    """

    def __init__(self, name, datastore, redis, filestore):
        super().__init__('assemblyline.service.' + name)
        self.service_name = name
        self.datastore = datastore
        self.filestore = filestore
        self.queue = get_service_queue(name, redis)
        self.dispatch_client = DispatchClient(self.datastore, redis)
        self.hits = dict()
        self.drops = dict()

    def try_run(self):
        while self.running:
            task = self.dispatch_client.request_work('worker', self.service_name, '0', timeout=1)
            if not task:
                continue
            print(self.service_name, 'has received a job', task.sid)

            file = self.filestore.get(task.fileinfo.sha256)

            instructions = json.loads(file)
            instructions = instructions.get(self.service_name, {})
            print(self.service_name, 'following instruction:', instructions)
            hits = self.hits[task.fileinfo.sha256] = self.hits.get(task.fileinfo.sha256, 0) + 1

            if instructions.get('semaphore', False):
                _global_semaphore.acquire(blocking=True, timeout=instructions['semaphore'])
                continue

            if 'drop' in instructions:
                if instructions['drop'] >= hits:
                    self.drops[task.fileinfo.sha256] = self.drops.get(task.fileinfo.sha256, 0) + 1
                    continue

            if instructions.get('failure', False):
                error = Error(instructions['error'])
                error.sha256 = task.fileinfo.sha256
                self.dispatch_client.service_failed(task.sid, error=error, error_key=get_random_id())
                continue

            result_data = {
                'archive_ts': time.time() + 300,
                'classification': 'U',
                'response': {
                    'service_version': '0',
                    'service_tool_version': '0',
                    'service_name': self.service_name,
                },
                'result': {},
                'sha256': task.fileinfo.sha256,
                'expiry_ts': time.time() + 600
            }
            result_data.update(instructions.get('result', {}))
            result_data['response'].update(instructions.get('response', {}))

            result = Result(result_data)
            result_key = instructions.get('result_key', get_random_id())
            self.dispatch_client.service_finished(task.sid, result_key, result)
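
# A sketch of wiring MockService into a dispatcher test. Assumptions: the
# start_mock_service helper is hypothetical, FileStore.put(path, bytes) stores
# content retrievable by hash, and ServerBase.start() runs try_run() on a
# background thread. The instructions dict mirrors the JSON layout try_run()
# reads back from the filestore: keyed by service name, e.g.
# {'extract': {'result': {...}, 'drop': 1}}.
import hashlib
import json


def start_mock_service(name, datastore, redis, filestore, instructions):
    # Persist the instruction file under its own sha256 so the mock can fetch it
    body = json.dumps(instructions).encode()
    sha256 = hashlib.sha256(body).hexdigest()
    filestore.put(sha256, body)

    service = MockService(name, datastore, redis, filestore)
    service.start()
    return service, sha256
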