def get_working_and_idle(redis, current_service):
    status_table = ExpiringHash(SERVICE_STATE_HASH, host=redis, ttl=30 * 60)
    service_data = status_table.items()

    busy = []
    idle = []
    for host, (service, state, time_limit) in service_data.items():
        if service == current_service:
            if time.time() < time_limit:
                if state == ServiceStatus.Running:
                    busy.append(host)
                else:
                    idle.append(host)
    return busy, idle
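# A minimal usage sketch (not from the original module): one way a caller could turn the
# (busy, idle) split returned above into a rough utilization figure for a service. The
# helper name and its arguments are hypothetical; only get_working_and_idle() and the
# module's existing imports are assumed.
def _example_service_utilization(redis_client, service_name: str) -> float:
    busy, idle = get_working_and_idle(redis_client, service_name)
    total = len(busy) + len(idle)
    # Fraction of reporting hosts currently marked ServiceStatus.Running.
    return len(busy) / total if total else 0.0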
class ScalerServer(CoreBase): def __init__(self, config=None, datastore=None, redis=None, redis_persist=None): super().__init__('assemblyline.scaler', config=config, datastore=datastore, redis=redis, redis_persist=redis_persist) self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=self.redis_persist) self.error_count = {} self.status_table = ExpiringHash(SERVICE_STATE_HASH, host=self.redis, ttl=30 * 60) labels = { 'app': 'assemblyline', 'section': 'service', } if KUBERNETES_AL_CONFIG: self.log.info( f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}" ) self.controller = KubernetesController( logger=self.log, prefix='alsvc_', labels=labels, namespace=NAMESPACE, priority='al-service-priority') # If we know where to find it, mount the classification into the service containers if CLASSIFICATION_CONFIGMAP: self.controller.config_mount( 'classification-config', config_map=CLASSIFICATION_CONFIGMAP, key=CLASSIFICATION_CONFIGMAP_KEY, target_path='/etc/assemblyline/classification.yml') else: self.log.info("Loading Docker cluster interface.") self.controller = DockerController( logger=self.log, prefix=NAMESPACE, cpu_overallocation=self.config.core.scaler.cpu_overallocation, memory_overallocation=self.config.core.scaler. memory_overallocation, labels=labels) # If we know where to find it, mount the classification into the service containers if CLASSIFICATION_HOST_PATH: self.controller.global_mounts.append( (CLASSIFICATION_HOST_PATH, '/etc/assemblyline/classification.yml')) self.profiles: Dict[str, ServiceProfile] = {} # Prepare a single threaded scheduler self.state = collection.Collection( period=self.config.core.metrics.export_interval) self.scheduler = sched.scheduler() self.scheduler_stopped = threading.Event() def add_service(self, profile: ServiceProfile): profile.desired_instances = max( self.controller.get_target(profile.name), profile.min_instances) profile.running_instances = profile.desired_instances self.log.debug( f'Starting service {profile.name} with a target of {profile.desired_instances}' ) profile.last_update = time.time() self.profiles[profile.name] = profile self.controller.add_profile(profile) def try_run(self): # Do an initial call to the main methods, who will then be registered with the scheduler self.sync_services() self.sync_metrics() self.update_scaling() self.expire_errors() self.process_timeouts() self.export_metrics() self.flush_service_status() self.log_container_events() self.heartbeat() # Run as long as we need to while self.running: delay = self.scheduler.run(False) time.sleep(min(delay, 2)) self.scheduler_stopped.set() def stop(self): super().stop() self.scheduler_stopped.wait(5) self.controller.stop() def heartbeat(self): """Periodically touch a file on disk. Since tasks are run serially, the delay between touches will be the maximum of HEARTBEAT_INTERVAL and the longest running task. 
""" if self.config.logging.heartbeat_file: self.scheduler.enter(HEARTBEAT_INTERVAL, 0, self.heartbeat) super().heartbeat() def sync_services(self): self.scheduler.enter(SERVICE_SYNC_INTERVAL, 0, self.sync_services) default_settings = self.config.core.scaler.service_defaults image_variables = defaultdict(str) image_variables.update(self.config.services.image_variables) current_services = set(self.profiles.keys()) discovered_services = [] # Get all the service data for service in self.datastore.list_all_services(full=True): service: Service = service name = service.name stage = self.get_service_stage(service.name) discovered_services.append(name) # noinspection PyBroadException try: if service.enabled and stage == ServiceStage.Off: # Enable this service's dependencies self.controller.prepare_network( service.name, service.docker_config.allow_internet_access) for _n, dependency in service.dependencies.items(): self.controller.start_stateful_container( service_name=service.name, container_name=_n, spec=dependency, labels={'dependency_for': service.name}) # Move to the next service stage if service.update_config and service.update_config.wait_for_update: self._service_stage_hash.set(name, ServiceStage.Update) else: self._service_stage_hash.set(name, ServiceStage.Running) if not service.enabled: self.stop_service(service.name, stage) continue # Check that all enabled services are enabled if service.enabled and stage == ServiceStage.Running: # Compute a hash of service properties not include in the docker config, that # should still result in a service being restarted when changed config_hash = hash(str(sorted(service.config.items()))) config_hash = hash( (config_hash, str(service.submission_params))) # Build the docker config for the service, we are going to either create it or # update it so we need to know what the current configuration is either way docker_config = service.docker_config docker_config.image = Template( docker_config.image).safe_substitute(image_variables) set_keys = set(var.name for var in docker_config.environment) for var in default_settings.environment: if var.name not in set_keys: docker_config.environment.append(var) # Add the service to the list of services being scaled if name not in self.profiles: self.log.info(f'Adding {service.name} to scaling') self.add_service( ServiceProfile( name=name, min_instances=default_settings.min_instances, growth=default_settings.growth, shrink=default_settings.shrink, config_hash=config_hash, backlog=default_settings.backlog, max_instances=service.licence_count, container_config=docker_config, queue=get_service_queue(name, self.redis), shutdown_seconds=service.timeout + 30, # Give service an extra 30 seconds to upload results )) # Update RAM, CPU, licence requirements for running services else: profile = self.profiles[name] if profile.container_config != docker_config or profile.config_hash != config_hash: self.log.info( f"Updating deployment information for {name}") profile.container_config = docker_config profile.config_hash = config_hash self.controller.restart(profile) self.log.info( f"Deployment information for {name} replaced") if service.licence_count == 0: profile._max_instances = float('inf') else: profile._max_instances = service.licence_count except Exception: self.log.exception( f"Error applying service settings from: {service.name}") self.handle_service_error(service.name) # Find any services we have running, that are no longer in the database and remove them for stray_service in current_services - set(discovered_services): 
stage = self.get_service_stage(stray_service) self.stop_service(stray_service, stage) def stop_service(self, name, current_stage): if current_stage != ServiceStage.Off: # Disable this service's dependencies self.controller.stop_containers(labels={'dependency_for': name}) # Mark this service as not running in the shared record self._service_stage_hash.set(name, ServiceStage.Off) # Stop any running disabled services if name in self.profiles or self.controller.get_target(name) > 0: self.log.info(f'Removing {name} from scaling') self.controller.set_target(name, 0) self.profiles.pop(name, None) def update_scaling(self): """Check if we need to scale any services up or down.""" self.scheduler.enter(SCALE_INTERVAL, 0, self.update_scaling) try: # Figure out what services are expected to be running and how many profiles: List[ServiceProfile] = list(self.profiles.values()) targets = { _p.name: self.controller.get_target(_p.name) for _p in profiles } for name, profile in self.profiles.items(): self.log.debug(f'{name}') self.log.debug( f'Instances \t{profile.min_instances} < {profile.desired_instances} | ' f'{targets[name]} < {profile.max_instances}') self.log.debug( f'Pressure \t{profile.shrink_threshold} < {profile.pressure} < {profile.growth_threshold}' ) # # 1. Any processes that want to release resources can always be approved first # for name, profile in self.profiles.items(): if targets[name] > profile.desired_instances: self.log.info( f"{name} wants less resources changing allocation " f"{targets[name]} -> {profile.desired_instances}") self.controller.set_target(name, profile.desired_instances) targets[name] = profile.desired_instances if not self.running: return # # 2. Any processes that aren't reaching their min_instances target must be given # more resources before anyone else is considered. # for name, profile in self.profiles.items(): if targets[name] < profile.min_instances: self.log.info( f"{name} isn't meeting minimum allocation " f"{targets[name]} -> {profile.min_instances}") self.controller.set_target(name, profile.min_instances) targets[name] = profile.min_instances # # 3. Try to estimate available resources, and based on some metric grant the # resources to each service that wants them. While this free memory # pool might be spread across many nodes, we are going to treat it like # it is one big one, and let the orchestration layer sort out the details. # free_cpu = self.controller.free_cpu() free_memory = self.controller.free_memory() # def trim(prof: List[ServiceProfile]): prof = [ _p for _p in prof if _p.desired_instances > targets[_p.name] ] drop = [ _p for _p in prof if _p.cpu > free_cpu or _p.ram > free_memory ] if drop: drop = {_p.name: (_p.cpu, _p.ram) for _p in drop} self.log.debug( f"Can't make more because not enough resources {drop}") prof = [ _p for _p in prof if _p.cpu <= free_cpu and _p.ram <= free_memory ] return prof profiles = trim(profiles) while profiles: # TODO do we need to add balancing metrics other than 'least running' for this? 
probably if True: profiles.sort( key=lambda _p: self.controller.get_target(_p.name)) # Add one for the profile at the bottom free_memory -= profiles[0].container_config.ram_mb free_cpu -= profiles[0].container_config.cpu_cores targets[profiles[0].name] += 1 # profiles = [_p for _p in profiles if _p.desired_instances > targets[_p.name]] # profiles = [_p for _p in profiles if _p.cpu < free_cpu and _p.ram < free_memory] profiles = trim(profiles) # Apply those adjustments we have made back to the controller for name, value in targets.items(): old = self.controller.get_target(name) if value != old: self.log.info(f"Scaling service {name}: {old} -> {value}") self.controller.set_target(name, value) if not self.running: return except ServiceControlError as error: self.log.exception("Error while scaling services.") self.handle_service_error(error.service_name) def handle_service_error(self, service_name): """Handle an error occurring in the *analysis* service. Errors for core systems should simply be logged, and a best effort to continue made. For analysis services, ignore the error a few times, then disable the service. """ self.error_count[service_name] = self.error_count.get(service_name, 0) + 1 if self.error_count[service_name] >= MAXIMUM_SERVICE_ERRORS: self.datastore.service_delta.update( service_name, [(self.datastore.service_delta.UPDATE_SET, 'enabled', False)]) del self.error_count[service_name] def sync_metrics(self): """Check if there are any pubsub messages we need.""" self.scheduler.enter(METRIC_SYNC_INTERVAL, 3, self.sync_metrics) # Pull service metrics from redis service_data = self.status_table.items() for host, (service, state, time_limit) in service_data.items(): # If an entry hasn't expired, take it into account if time.time() < time_limit: self.state.update(service=service, host=host, throughput=0, busy_seconds=METRIC_SYNC_INTERVAL if state == ServiceStatus.Running else 0) # If an entry expired a while ago, the host is probably not in use any more if time.time() > time_limit + 600: self.status_table.pop(host) # Check the set of services that might be sitting at zero instances, and if it is, we need to # manually check if it is offline export_interval = self.config.core.metrics.export_interval for profile_name, profile in self.profiles.items(): # Pull out statistics from the metrics regularization update = self.state.read(profile_name) if update: delta = time.time() - profile.last_update profile.update(delta=delta, backlog=profile.queue.length(), **update) # Check if we expect no messages, if so pull the queue length ourselves since there is no heartbeat if self.controller.get_target( profile_name ) == 0 and profile.desired_instances == 0 and profile.queue: queue_length = profile.queue.length() if queue_length > 0: self.log.info(f"Service at zero instances has messages: " f"{profile.name} ({queue_length} in queue)") profile.update(delta=export_interval, instances=0, backlog=queue_length, duty_cycle=profile.target_duty_cycle) # TODO maybe find another way of implementing this that is less aggressive # for profile_name, profile in self.profiles.items(): # # In the case that there should actually be instances running, but we haven't gotten # # any heartbeat messages we might be waiting for a container that can't start properly # if self.services.controller.get_target(profile_name) > 0: # if time.time() - profile.last_update > profile.shutdown_seconds: # self.log.error(f"Starting service {profile_name} has timed out " # f"({time.time() - profile.last_update} > 
        # {profile.shutdown_seconds} seconds)")
        #
        #             # Disable the service
        #             self.datastore.service_delta.update(profile_name, [
        #                 (self.datastore.service_delta.UPDATE_SET, 'enabled', False)
        #             ])

    def expire_errors(self):
        self.scheduler.enter(ERROR_EXPIRY_INTERVAL, 0, self.expire_errors)
        self.error_count = {name: err - 1 for name, err in self.error_count.items() if err > 1}

    def process_timeouts(self):
        self.scheduler.enter(PROCESS_TIMEOUT_INTERVAL, 0, self.process_timeouts)
        while True:
            message = self.scaler_timeout_queue.pop(blocking=False)
            if not message:
                break

            # noinspection PyBroadException
            try:
                self.log.info(f"Killing service container: {message['container']} running: {message['service']}")
                self.controller.stop_container(message['service'], message['container'])
            except Exception:
                self.log.exception(f"Exception trying to stop timed out service container: {message}")

    def export_metrics(self):
        self.scheduler.enter(self.config.logging.export_interval, 0, self.export_metrics)

        for service_name, profile in self.profiles.items():
            metrics = {
                'running': profile.running_instances,
                'target': profile.desired_instances,
                'minimum': profile.min_instances,
                'maximum': profile.instance_limit,
                'dynamic_maximum': profile.max_instances,
                'queue': profile.queue_length,
                'duty_cycle': profile.duty_cycle,
                'pressure': profile.pressure
            }
            export_metrics_once(service_name, Status, metrics, host=HOSTNAME,
                                counter_type='scaler-status', config=self.config, redis=self.redis)

        memory, memory_total = self.controller.memory_info()
        cpu, cpu_total = self.controller.cpu_info()
        metrics = {
            'memory_total': memory_total,
            'cpu_total': cpu_total,
            'memory_free': memory,
            'cpu_free': cpu
        }
        export_metrics_once('scaler', Metrics, metrics, host=HOSTNAME,
                            counter_type='scaler', config=self.config, redis=self.redis)

    def flush_service_status(self):
        """The service status table may have references to containers that have crashed. Try to remove them all."""
        self.scheduler.enter(SERVICE_STATUS_FLUSH, 0, self.flush_service_status)

        # Pull all container names
        names = set(self.controller.get_running_container_names())

        # Get the names we have status for
        for hostname in self.status_table.keys():
            if hostname not in names:
                self.status_table.pop(hostname)

    def log_container_events(self):
        """Watch for new events from the container controller and log them."""
        self.scheduler.enter(CONTAINER_EVENTS_LOG_INTERVAL, 0, self.log_container_events)

        for message in self.controller.new_events():
            self.log.warning("Container Event :: " + message)
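# Illustrative sketch (not part of the original class): the self-rescheduling pattern used
# above, where each periodic method re-registers itself with sched.scheduler and try_run()
# alternates scheduler.run(False) with short sleeps so shutdown can be noticed promptly.
# The names _demo_periodic_pattern and demo_task are hypothetical.
import sched
import time


def _demo_periodic_pattern(interval: float = 0.1, cycles: int = 3) -> int:
    scheduler = sched.scheduler()
    runs = {'count': 0}

    def demo_task():
        # Re-register first, mirroring sync_metrics()/expire_errors() above.
        scheduler.enter(interval, 0, demo_task)
        runs['count'] += 1

    demo_task()  # the initial call registers the next run, like try_run() does
    while runs['count'] < cycles:
        delay = scheduler.run(blocking=False)  # run due tasks, return the delay to the next one
        time.sleep(min(delay, 2))
    return runs['count']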
def dispatch_file(self, task: FileTask):
    """Handle a message describing a file to be processed.

    This file may be:
        - A new submission or extracted file.
        - A file that has just completed a stage of processing.
        - A file that has not completed a stage of processing, but this call has been
          triggered by a timeout or similar.

    If the file is totally new, we will set up a dispatch table, and fill it in.

    Once we make/load a dispatch table, we will dispatch whichever group the table
    shows us hasn't been completed yet.

    When we dispatch to a service, we check if the task is already in the dispatch
    queue. If it isn't, proceed normally. If it is, check that the service is still online.
    """
    # Read the message content
    file_hash = task.file_info.sha256
    active_task = self.active_submissions.get(task.sid)

    if active_task is None:
        self.log.warning(f"[{task.sid}] Untracked submission is being processed")
        return

    submission_task = SubmissionTask(active_task)
    submission = submission_task.submission

    # Refresh the watch on the submission, we are still working on it
    self.timeout_watcher.touch(key=task.sid, timeout=int(self.config.core.dispatcher.timeout),
                               queue=SUBMISSION_QUEUE, message={'sid': task.sid})

    # Open up the file/service table for this submission
    dispatch_table = DispatchHash(task.sid, self.redis, fetch_results=True)

    # Load things that we will need to fill out the service task messages
    file_tags = ExpiringSet(task.get_tag_set_name(), host=self.redis)
    file_tags_data = file_tags.members()
    temporary_submission_data = ExpiringHash(task.get_temporary_submission_data_name(), host=self.redis)
    temporary_data = [dict(name=row[0], value=row[1]) for row in temporary_submission_data.items().items()]

    # Calculate the schedule for the file
    schedule = self.build_schedule(dispatch_table, submission, file_hash, task.file_info.type)
    started_stages = []

    # Go through each round of the schedule removing complete/failed services
    # Break when we find a stage that still needs processing
    outstanding = {}
    score = 0
    errors = 0
    while schedule and not outstanding:
        stage = schedule.pop(0)
        started_stages.append(stage)

        for service_name in stage:
            service = self.scheduler.services.get(service_name)
            if not service:
                continue

            # Load the results, if there are no results, then the service must be dispatched later
            # Don't look at if it has been dispatched, as multiple dispatches are fine,
            # but missing a dispatch isn't.
            finished = dispatch_table.finished(file_hash, service_name)
            if not finished:
                outstanding[service_name] = service
                continue

            # If the service terminated in an error, count the error and continue
            if finished.is_error:
                errors += 1
                continue

            # If the service finished, count the score, and check if the file has been dropped
            score += finished.score
            if not submission.params.ignore_filtering and finished.drop:
                schedule.clear()

    if schedule:
        # If there are still stages in the schedule, overwrite them for next time
        dispatch_table.schedules.set(file_hash, started_stages)

    # Try to retry/dispatch any outstanding services
    if outstanding:
        self.log.info(f"[{task.sid}] File {file_hash} sent to services: {', '.join(list(outstanding.keys()))}")

        for service_name, service in outstanding.items():
            # Find the actual file name from the list of files in submission
            filename = None
            for file in submission.files:
                if task.file_info.sha256 == file.sha256:
                    filename = file.name
                    break

            # Build the actual service dispatch message
            config = self.build_service_config(service, submission)
            service_task = ServiceTask(dict(
                sid=task.sid,
                metadata=submission.metadata,
                min_classification=task.min_classification,
                service_name=service_name,
                service_config=config,
                fileinfo=task.file_info,
                filename=filename or task.file_info.sha256,
                depth=task.depth,
                max_files=task.max_files,
                ttl=submission.params.ttl,
                ignore_cache=submission.params.ignore_cache,
                ignore_dynamic_recursion_prevention=submission.params.ignore_dynamic_recursion_prevention,
                tags=file_tags_data,
                temporary_submission_data=temporary_data,
                deep_scan=submission.params.deep_scan,
                priority=submission.params.priority,
            ))
            dispatch_table.dispatch(file_hash, service_name)
            queue = get_service_queue(service_name, self.redis)
            queue.push(service_task.priority, service_task.as_primitives())

    else:
        # There are no outstanding services, this file is done
        # Clean up the tags
        file_tags.delete()

        # If there are no outstanding ANYTHING for this submission,
        # send a message to the submission dispatcher to finalize
        self.counter.increment('files_completed')
        if dispatch_table.all_finished():
            self.log.info(f"[{task.sid}] Finished processing file '{file_hash}', starting submission finalization.")
            self.submission_queue.push({'sid': submission.sid})
        else:
            self.log.info(f"[{task.sid}] Finished processing file '{file_hash}'. Other files are not finished.")
def service_finished(self, sid: str, result_key: str, result: Result, temporary_data: Optional[Dict[str, Any]] = None): """Notifies the dispatcher of service completion, and possible new files to dispatch.""" # Make sure the dispatcher knows we were working on this task task_key = ServiceTask.make_key( sid=sid, service_name=result.response.service_name, sha=result.sha256) task = self.running_tasks.pop(task_key) if not task: self.log.warning( f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified " f"task in its set of running tasks while processing successful results." ) return task = ServiceTask(task) # Check if the service is a candidate for dynamic recursion prevention if not task.ignore_dynamic_recursion_prevention: service_info = self.service_data.get(result.response.service_name, None) if service_info and service_info.category == "Dynamic Analysis": # TODO: This should be done in lua because it can introduce race condition in the future # but in the meantime it will remain this way while we can confirm it work as expected submission = self.active_submissions.get(sid) submission['submission']['params']['services'][ 'runtime_excluded'].append(result.response.service_name) self.active_submissions.set(sid, submission) # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the # most distant expiry time to prevent pulling it out from under another submission too early if result.is_empty(): # Empty Result will not be archived therefore result.archive_ts drives their deletion self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts}) else: with Lock(f"lock-{result_key}", 5, self.redis): old = self.ds.result.get(result_key) if old: if old.expiry_ts and result.expiry_ts: result.expiry_ts = max(result.expiry_ts, old.expiry_ts) else: result.expiry_ts = None self.ds.result.save(result_key, result) # Let the logs know we have received a result for this task if result.drop_file: self.log.debug( f"[{sid}/{result.sha256}] {task.service_name} succeeded. " f"Result will be stored in {result_key} but processing will stop after this service." ) else: self.log.debug( f"[{sid}/{result.sha256}] {task.service_name} succeeded. 
" f"Result will be stored in {result_key}") # Store the result object and mark the service finished in the global table process_table = DispatchHash(task.sid, self.redis) remaining, duplicate = process_table.finish( task.fileinfo.sha256, task.service_name, result_key, result.result.score, result.classification, result.drop_file) self.timeout_watcher.clear(f'{task.sid}-{task.key()}') if duplicate: self.log.warning( f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already " f"completed in the global processing table.") return # Push the result tags into redis new_tags = [] for section in result.result.sections: new_tags.extend(tag_dict_to_list(section.tags.as_primitives())) if new_tags: tag_set = ExpiringSet(get_tag_set_name( sid=task.sid, file_hash=task.fileinfo.sha256), host=self.redis) tag_set.add(*new_tags) # Update the temporary data table for this file temp_data_hash = ExpiringHash(get_temporary_submission_data_name( sid=task.sid, file_hash=task.fileinfo.sha256), host=self.redis) for key, value in (temporary_data or {}).items(): temp_data_hash.set(key, value) # Send the extracted files to the dispatcher depth_limit = self.config.submission.max_extraction_depth new_depth = task.depth + 1 if new_depth < depth_limit: # Prepare the temporary data from the parent to build the temporary data table for # these newly extract files parent_data = dict(temp_data_hash.items()) for extracted_data in result.response.extracted: if not process_table.add_file( extracted_data.sha256, task.max_files, parent_hash=task.fileinfo.sha256): if parent_data: child_hash_name = get_temporary_submission_data_name( task.sid, extracted_data.sha256) ExpiringHash(child_hash_name, host=self.redis).multi_set(parent_data) self._dispatching_error( task, process_table, Error({ 'archive_ts': result.archive_ts, 'expiry_ts': result.expiry_ts, 'response': { 'message': f"Too many files extracted for submission {task.sid} " f"{extracted_data.sha256} extracted by " f"{task.service_name} will be dropped", 'service_name': task.service_name, 'service_tool_version': result.response.service_tool_version, 'service_version': result.response.service_version, 'status': 'FAIL_NONRECOVERABLE' }, 'sha256': extracted_data.sha256, 'type': 'MAX FILES REACHED' })) continue file_data = self.files.get(extracted_data.sha256) self.file_queue.push( FileTask( dict(sid=task.sid, min_classification=task.min_classification.max( extracted_data.classification).value, file_info=dict( magic=file_data.magic, md5=file_data.md5, mime=file_data.mime, sha1=file_data.sha1, sha256=file_data.sha256, size=file_data.size, type=file_data.type, ), depth=new_depth, parent_hash=task.fileinfo.sha256, max_files=task.max_files)).as_primitives()) else: for extracted_data in result.response.extracted: self._dispatching_error( task, process_table, Error({ 'archive_ts': result.archive_ts, 'expiry_ts': result.expiry_ts, 'response': { 'message': f"{task.service_name} has extracted a file " f"{extracted_data.sha256} beyond the depth limits", 'service_name': result.response.service_name, 'service_tool_version': result.response.service_tool_version, 'service_version': result.response.service_version, 'status': 'FAIL_NONRECOVERABLE' }, 'sha256': extracted_data.sha256, 'type': 'MAX DEPTH REACHED' })) # If the global table said that this was the last outstanding service, # send a message to the dispatchers. 
    if remaining <= 0:
        self.file_queue.push(FileTask(dict(
            sid=task.sid,
            min_classification=task.min_classification.value,
            file_info=task.fileinfo,
            depth=task.depth,
            max_files=task.max_files,
        )).as_primitives())

    # Send the result key to any watching systems
    msg = {'status': 'OK', 'cache_key': result_key}
    for w in self._get_watcher_list(task.sid).members():
        NamedQueue(w).push(msg)
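# Illustrative sketch (not from the original dispatcher): the result-freshening rule applied
# under the lock in service_finished() above, isolated as a pure function. The content of a
# re-saved result is assumed unchanged, but the stored copy must keep the most distant expiry
# so it is not removed out from under another submission. The helper name is hypothetical and
# timestamps are plain floats here rather than model fields.
from typing import Optional


def _merged_expiry(new_expiry: Optional[float], old_expiry: Optional[float]) -> Optional[float]:
    # Mirror the logic above: if either copy has no expiry, the merged record never expires;
    # otherwise keep whichever expiry is later.
    if new_expiry is None or old_expiry is None:
        return None
    return max(new_expiry, old_expiry)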
class ScalerServer(ThreadedCoreBase): def __init__(self, config=None, datastore=None, redis=None, redis_persist=None): super().__init__('assemblyline.scaler', config=config, datastore=datastore, redis=redis, redis_persist=redis_persist) self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=self.redis_persist) self.error_count_lock = threading.Lock() self.error_count: dict[str, list[float]] = {} self.status_table = ExpiringHash(SERVICE_STATE_HASH, host=self.redis, ttl=30 * 60) self.service_event_sender = EventSender('changes.services', host=self.redis) self.service_change_watcher = EventWatcher( self.redis, deserializer=ServiceChange.deserialize) self.service_change_watcher.register('changes.services.*', self._handle_service_change_event) core_env: dict[str, str] = {} # If we have privileged services, we must be able to pass the necessary environment variables for them to # function properly. for secret in re.findall( r'\${\w+}', open('/etc/assemblyline/config.yml', 'r').read()) + ['UI_SERVER']: env_name = secret.strip("${}") core_env[env_name] = os.environ[env_name] labels = { 'app': 'assemblyline', 'section': 'service', 'privilege': 'service' } if self.config.core.scaler.additional_labels: labels.update({ k: v for k, v in ( _l.split("=") for _l in self.config.core.scaler.additional_labels) }) if KUBERNETES_AL_CONFIG: self.log.info( f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}" ) self.controller = KubernetesController( logger=self.log, prefix='alsvc_', labels=labels, namespace=NAMESPACE, priority='al-service-priority', cpu_reservation=self.config.services.cpu_reservation, log_level=self.config.logging.log_level, core_env=core_env) # If we know where to find it, mount the classification into the service containers if CLASSIFICATION_CONFIGMAP: self.controller.config_mount( 'classification-config', config_map=CLASSIFICATION_CONFIGMAP, key=CLASSIFICATION_CONFIGMAP_KEY, target_path='/etc/assemblyline/classification.yml') if CONFIGURATION_CONFIGMAP: self.controller.core_config_mount( 'assemblyline-config', config_map=CONFIGURATION_CONFIGMAP, key=CONFIGURATION_CONFIGMAP_KEY, target_path='/etc/assemblyline/config.yml') else: self.log.info("Loading Docker cluster interface.") self.controller = DockerController( logger=self.log, prefix=NAMESPACE, labels=labels, log_level=self.config.logging.log_level, core_env=core_env) self._service_stage_hash.delete() if DOCKER_CONFIGURATION_PATH and DOCKER_CONFIGURATION_VOLUME: self.controller.core_mounts.append( (DOCKER_CONFIGURATION_VOLUME, '/etc/assemblyline/')) with open( os.path.join(DOCKER_CONFIGURATION_PATH, 'config.yml'), 'w') as handle: yaml.dump(self.config.as_primitives(), handle) with open( os.path.join(DOCKER_CONFIGURATION_PATH, 'classification.yml'), 'w') as handle: yaml.dump(get_classification().original_definition, handle) # If we know where to find it, mount the classification into the service containers if CLASSIFICATION_HOST_PATH: self.controller.global_mounts.append( (CLASSIFICATION_HOST_PATH, '/etc/assemblyline/classification.yml')) # Information about services self.profiles: dict[str, ServiceProfile] = {} self.profiles_lock = threading.RLock() # Prepare a single threaded scheduler self.state = collection.Collection( period=self.config.core.metrics.export_interval) self.stopping = threading.Event() self.main_loop_exit = threading.Event() # Load the APM connection if any self.apm_client = None if self.config.core.metrics.apm_server.server_url: elasticapm.instrument() self.apm_client = elasticapm.Client( 
server_url=self.config.core.metrics.apm_server.server_url, service_name="scaler") def log_crashes(self, fn): @functools.wraps(fn) def with_logs(*args, **kwargs): # noinspection PyBroadException try: fn(*args, **kwargs) except ServiceControlError as error: self.log.exception( f"Error while managing service: {error.service_name}") self.handle_service_error(error.service_name) except Exception: self.log.exception(f'Crash in scaler: {fn.__name__}') return with_logs @elasticapm.capture_span(span_type=APM_SPAN_TYPE) def add_service(self, profile: ServiceProfile): # We need to hold the lock the whole time we add the service, # we don't want the scaling thread trying to adjust the scale of a # deployment we haven't added to the system yet with self.profiles_lock: profile.desired_instances = max( self.controller.get_target(profile.name), profile.min_instances) profile.running_instances = profile.desired_instances profile.target_instances = profile.desired_instances self.log.debug( f'Starting service {profile.name} with a target of {profile.desired_instances}' ) profile.last_update = time.time() self.profiles[profile.name] = profile self.controller.add_profile(profile, scale=profile.desired_instances) def try_run(self): self.service_change_watcher.start() self.maintain_threads({ 'Log Container Events': self.log_container_events, 'Process Timeouts': self.process_timeouts, 'Service Configuration Sync': self.sync_services, 'Service Adjuster': self.update_scaling, 'Import Metrics': self.sync_metrics, 'Export Metrics': self.export_metrics, }) def stop(self): super().stop() self.service_change_watcher.stop() self.controller.stop() def _handle_service_change_event(self, data: ServiceChange): if data.operation == Operation.Removed: self.log.info( f'Service appears to be deleted, removing {data.name}') stage = self.get_service_stage(data.name) self.stop_service(data.name, stage) elif data.operation == Operation.Incompatible: return else: self._sync_service(self.datastore.get_service_with_delta( data.name)) def sync_services(self): while self.running: with apm_span(self.apm_client, 'sync_services'): with self.profiles_lock: current_services = set(self.profiles.keys()) discovered_services: list[str] = [] # Get all the service data for service in self.datastore.list_all_services(full=True): self._sync_service(service) discovered_services.append(service.name) # Find any services we have running, that are no longer in the database and remove them for stray_service in current_services - set( discovered_services): self.log.info( f'Service appears to be deleted, removing stray {stray_service}' ) stage = self.get_service_stage(stray_service) self.stop_service(stray_service, stage) self.sleep(SERVICE_SYNC_INTERVAL) def _sync_service(self, service: Service): name = service.name stage = self.get_service_stage(service.name) default_settings = self.config.core.scaler.service_defaults image_variables: defaultdict[str, str] = defaultdict(str) image_variables.update(self.config.services.image_variables) def prepare_container(docker_config: DockerConfig) -> DockerConfig: docker_config.image = Template( docker_config.image).safe_substitute(image_variables) set_keys = set(var.name for var in docker_config.environment) for var in default_settings.environment: if var.name not in set_keys: docker_config.environment.append(var) return docker_config # noinspection PyBroadException try: def disable_incompatible_service(): service.enabled = False if self.datastore.service_delta.update(service.name, [ 
(self.datastore.service_delta.UPDATE_SET, 'enabled', False) ]): # Raise awareness to other components by sending an event for the service self.service_event_sender.send(service.name, { 'operation': Operation.Incompatible, 'name': service.name }) # Check if service considered compatible to run on Assemblyline? system_spec = f'{FRAMEWORK_VERSION}.{SYSTEM_VERSION}' if not service.version.startswith(system_spec): # If FW and SYS version don't prefix in the service version, we can't guarantee the service is compatible # Disable and treat it as incompatible due to service version. self.log.warning( "Disabling service with incompatible version. " f"[{service.version} != '{system_spec}.X.{service.update_channel}Y']." ) disable_incompatible_service() elif service.update_config and service.update_config.wait_for_update and not service.update_config.sources: # All signatures sources from a signature-dependent service was removed # Disable and treat it as incompatible due to service configuration relative to source management self.log.warning( "Disabling service with incompatible service configuration. " "Signature-dependent service has no signature sources.") disable_incompatible_service() if not service.enabled: self.stop_service(service.name, stage) return # Build the docker config for the dependencies. For now the dependency blob values # aren't set for the change key going to kubernetes because everything about # the dependency config should be captured in change key that the function generates # internally. A change key is set for the service deployment as that includes # things like the submission params dependency_config: dict[str, Any] = {} dependency_blobs: dict[str, str] = {} for _n, dependency in service.dependencies.items(): dependency.container = prepare_container(dependency.container) dependency_config[_n] = dependency dep_hash = get_id_from_data(dependency, length=16) dependency_blobs[ _n] = f"dh={dep_hash}v={service.version}p={service.privileged}" # Check if the service dependencies have been deployed. 
dependency_keys = [] updater_ready = stage == ServiceStage.Running if service.update_config: for _n, dependency in dependency_config.items(): key = self.controller.stateful_container_key( service.name, _n, dependency, '') if key: dependency_keys.append(_n + key) else: updater_ready = False # If stage is not set to running or a dependency container is missing start the setup process if not updater_ready: self.log.info(f'Preparing environment for {service.name}') # Move to the next service stage (do this first because the container we are starting may care) if service.update_config and service.update_config.wait_for_update: self._service_stage_hash.set(name, ServiceStage.Update) stage = ServiceStage.Update else: self._service_stage_hash.set(name, ServiceStage.Running) stage = ServiceStage.Running # Enable this service's dependencies before trying to launch the service containers dependency_internet = [ (name, dependency.container.allow_internet_access) for name, dependency in dependency_config.items() ] self.controller.prepare_network( service.name, service.docker_config.allow_internet_access, dependency_internet) for _n, dependency in dependency_config.items(): self.log.info(f'Launching {service.name} dependency {_n}') self.controller.start_stateful_container( service_name=service.name, container_name=_n, spec=dependency, labels={'dependency_for': service.name}, change_key=dependency_blobs.get(_n, '')) # If the conditions for running are met deploy or update service containers if stage == ServiceStage.Running: # Build the docker config for the service, we are going to either create it or # update it so we need to know what the current configuration is either way docker_config = prepare_container(service.docker_config) # Compute a blob of service properties not include in the docker config, that # should still result in a service being restarted when changed cfg_items = get_recursive_sorted_tuples(service.config) dep_keys = ''.join(sorted(dependency_keys)) config_blob = ( f"c={cfg_items}sp={service.submission_params}" f"dk={dep_keys}p={service.privileged}d={docker_config}") # Add the service to the list of services being scaled with self.profiles_lock: if name not in self.profiles: self.log.info( f"Adding " f"{f'privileged {service.name}' if service.privileged else service.name}" " to scaling") self.add_service( ServiceProfile( name=name, min_instances=default_settings.min_instances, growth=default_settings.growth, shrink=default_settings.shrink, config_blob=config_blob, dependency_blobs=dependency_blobs, backlog=default_settings.backlog, max_instances=service.licence_count, container_config=docker_config, queue=get_service_queue(name, self.redis), # Give service an extra 30 seconds to upload results shutdown_seconds=service.timeout + 30, privileged=service.privileged)) # Update RAM, CPU, licence requirements for running services else: profile = self.profiles[name] profile.max_instances = service.licence_count profile.privileged = service.privileged for dependency_name, dependency_blob in dependency_blobs.items( ): if profile.dependency_blobs[ dependency_name] != dependency_blob: self.log.info( f"Updating deployment information for {name}/{dependency_name}" ) profile.dependency_blobs[ dependency_name] = dependency_blob self.controller.start_stateful_container( service_name=service.name, container_name=dependency_name, spec=dependency_config[dependency_name], labels={'dependency_for': service.name}, change_key=dependency_blob) if profile.config_blob != config_blob: self.log.info( f"Updating 
deployment information for {name}") profile.container_config = docker_config profile.config_blob = config_blob self.controller.restart(profile) self.log.info( f"Deployment information for {name} replaced") except Exception: self.log.exception( f"Error applying service settings from: {service.name}") self.handle_service_error(service.name) @elasticapm.capture_span(span_type=APM_SPAN_TYPE) def stop_service(self, name: str, current_stage: ServiceStage): if current_stage != ServiceStage.Off: # Disable this service's dependencies self.controller.stop_containers(labels={'dependency_for': name}) # Mark this service as not running in the shared record self._service_stage_hash.set(name, ServiceStage.Off) # Stop any running disabled services if name in self.profiles or self.controller.get_target(name) > 0: self.log.info(f'Removing {name} from scaling') with self.profiles_lock: self.profiles.pop(name, None) self.controller.set_target(name, 0) def update_scaling(self): """Check if we need to scale any services up or down.""" pool = Pool() while self.sleep(SCALE_INTERVAL): with apm_span(self.apm_client, 'update_scaling'): # Figure out what services are expected to be running and how many with elasticapm.capture_span('read_profiles'): with self.profiles_lock: all_profiles: dict[str, ServiceProfile] = copy.deepcopy( self.profiles) raw_targets = self.controller.get_targets() targets = { _p.name: raw_targets.get(_p.name, 0) for _p in all_profiles.values() } for name, profile in all_profiles.items(): self.log.debug(f'{name}') self.log.debug( f'Instances \t{profile.min_instances} < {profile.desired_instances} | ' f'{targets[name]} < {profile.max_instances}') self.log.debug( f'Pressure \t{profile.shrink_threshold} < ' f'{profile.pressure} < {profile.growth_threshold}') # # 1. Any processes that want to release resources can always be approved first # with pool: for name, profile in all_profiles.items(): if targets[name] > profile.desired_instances: self.log.info( f"{name} wants less resources changing allocation " f"{targets[name]} -> {profile.desired_instances}" ) pool.call(self.controller.set_target, name, profile.desired_instances) targets[name] = profile.desired_instances # # 2. Any processes that aren't reaching their min_instances target must be given # more resources before anyone else is considered. # for name, profile in all_profiles.items(): if targets[name] < profile.min_instances: self.log.info( f"{name} isn't meeting minimum allocation " f"{targets[name]} -> {profile.min_instances}") pool.call(self.controller.set_target, name, profile.min_instances) targets[name] = profile.min_instances # # 3. Try to estimate available resources, and based on some metric grant the # resources to each service that wants them. While this free memory # pool might be spread across many nodes, we are going to treat it like # it is one big one, and let the orchestration layer sort out the details. 
# # Recalculate the amount of free resources expanding the total quantity by the overallocation free_cpu, total_cpu = self.controller.cpu_info() used_cpu = total_cpu - free_cpu free_cpu = total_cpu * self.config.core.scaler.cpu_overallocation - used_cpu free_memory, total_memory = self.controller.memory_info() used_memory = total_memory - free_memory free_memory = total_memory * self.config.core.scaler.memory_overallocation - used_memory # def trim(prof: list[ServiceProfile]): prof = [ _p for _p in prof if _p.desired_instances > targets[_p.name] ] drop = [ _p for _p in prof if _p.cpu > free_cpu or _p.ram > free_memory ] if drop: summary = {_p.name: (_p.cpu, _p.ram) for _p in drop} self.log.debug( f"Can't make more because not enough resources {summary}" ) prof = [ _p for _p in prof if _p.cpu <= free_cpu and _p.ram <= free_memory ] return prof remaining_profiles: list[ServiceProfile] = trim( list(all_profiles.values())) # The target values up until now should be in sync with the container orchestrator # create a copy, so we can track which ones change in the following loop old_targets = dict(targets) while remaining_profiles: # TODO do we need to add balancing metrics other than 'least running' for this? probably remaining_profiles.sort(key=lambda _p: targets[_p.name]) # Add one for the profile at the bottom free_memory -= remaining_profiles[ 0].container_config.ram_mb free_cpu -= remaining_profiles[ 0].container_config.cpu_cores targets[remaining_profiles[0].name] += 1 # Take out any services that should be happy now remaining_profiles = trim(remaining_profiles) # Apply those adjustments we have made back to the controller with elasticapm.capture_span('write_targets'): with pool: for name, value in targets.items(): if name not in self.profiles: # A service was probably added/removed while we were # in the middle of this function continue self.profiles[name].target_instances = value old = old_targets[name] if value != old: self.log.info( f"Scaling service {name}: {old} -> {value}" ) pool.call(self.controller.set_target, name, value) @elasticapm.capture_span(span_type=APM_SPAN_TYPE) def handle_service_error(self, service_name: str): """Handle an error occurring in the *analysis* service. Errors for core systems should simply be logged, and a best effort to continue made. For analysis services, ignore the error a few times, then disable the service. """ with self.error_count_lock: try: self.error_count[service_name].append(time.time()) except KeyError: self.error_count[service_name] = [time.time()] self.error_count[service_name] = [ _t for _t in self.error_count[service_name] if _t >= time.time() - ERROR_EXPIRY_TIME ] if len(self.error_count[service_name]) >= MAXIMUM_SERVICE_ERRORS: self.log.warning( f"Scaler has encountered too many errors trying to load {service_name}. 
" "The service will be permanently disabled...") if self.datastore.service_delta.update(service_name, [ (self.datastore.service_delta.UPDATE_SET, 'enabled', False) ]): # Raise awareness to other components by sending an event for the service self.service_event_sender.send(service_name, { 'operation': Operation.Modified, 'name': service_name }) del self.error_count[service_name] def sync_metrics(self): """Check if there are any pub-sub messages we need.""" while self.sleep(METRIC_SYNC_INTERVAL): with apm_span(self.apm_client, 'sync_metrics'): # Pull service metrics from redis service_data = self.status_table.items() for host, (service, state, time_limit) in service_data.items(): # If an entry hasn't expired, take it into account if time.time() < time_limit: self.state.update( service=service, host=host, throughput=0, busy_seconds=METRIC_SYNC_INTERVAL if state == ServiceStatus.Running else 0) # If an entry expired a while ago, the host is probably not in use any more if time.time() > time_limit + 600: self.status_table.pop(host) # Download the current targets in the orchestrator while not holding the lock with self.profiles_lock: targets = { name: profile.target_instances for name, profile in self.profiles.items() } # Check the set of services that might be sitting at zero instances, and if it is, we need to # manually check if it is offline export_interval = self.config.core.metrics.export_interval with self.profiles_lock: queues = [ profile.queue for profile in self.profiles.values() if profile.queue ] lengths_list = pq_length(*queues) lengths = {_q: _l for _q, _l in zip(queues, lengths_list)} for profile_name, profile in self.profiles.items(): queue_length = lengths.get(profile.queue, 0) # Pull out statistics from the metrics regularization update = self.state.read(profile_name) if update: delta = time.time() - profile.last_update profile.update(delta=delta, backlog=queue_length, **update) # Check if we expect no messages, if so pull the queue length ourselves # since there is no heartbeat if targets.get( profile_name ) == 0 and profile.desired_instances == 0 and profile.queue: if queue_length > 0: self.log.info( f"Service at zero instances has messages: " f"{profile.name} ({queue_length} in queue)" ) profile.update(delta=export_interval, instances=0, backlog=queue_length, duty_cycle=profile.high_duty_cycle) def _timeout_kill(self, service, container): with apm_span(self.apm_client, 'timeout_kill'): self.controller.stop_container(service, container) self.status_table.pop(container) def process_timeouts(self): with concurrent.futures.ThreadPoolExecutor(10) as pool: futures = [] while self.running: message = self.scaler_timeout_queue.pop(blocking=True, timeout=1) if not message: continue with apm_span(self.apm_client, 'process_timeouts'): # Process new messages self.log.info( f"Killing service container: {message['container']} running: {message['service']}" ) futures.append( pool.submit(self._timeout_kill, message['service'], message['container'])) # Process finished finished = [_f for _f in futures if _f.done()] futures = [_f for _f in futures if _f not in finished] for _f in finished: exception = _f.exception() if exception is not None: self.log.error( f"Exception trying to stop timed out service container: {exception}" ) def export_metrics(self): while self.sleep(self.config.logging.export_interval): with apm_span(self.apm_client, 'export_metrics'): service_metrics = {} with self.profiles_lock: for service_name, profile in self.profiles.items(): service_metrics[service_name] = { 
                            'running': profile.running_instances,
                            'target': profile.target_instances,
                            'minimum': profile.min_instances,
                            'maximum': profile.instance_limit,
                            'dynamic_maximum': profile.max_instances,
                            'queue': profile.queue_length,
                            'duty_cycle': profile.duty_cycle,
                            'pressure': profile.pressure
                        }

                for service_name, metrics in service_metrics.items():
                    export_metrics_once(service_name, Status, metrics, host=HOSTNAME,
                                        counter_type='scaler_status', config=self.config, redis=self.redis)

                memory, memory_total = self.controller.memory_info()
                cpu, cpu_total = self.controller.cpu_info()
                metrics = {
                    'memory_total': memory_total,
                    'cpu_total': cpu_total,
                    'memory_free': memory,
                    'cpu_free': cpu
                }
                export_metrics_once('scaler', Metrics, metrics, host=HOSTNAME,
                                    counter_type='scaler', config=self.config, redis=self.redis)

    def log_container_events(self):
        """Watch for new events from the container controller and log them."""
        while self.sleep(CONTAINER_EVENTS_LOG_INTERVAL):
            with apm_span(self.apm_client, 'log_container_events'):
                for message in self.controller.new_events():
                    self.log.warning("Container Event :: " + message)
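# Illustrative sketch (not part of the original class): the greedy allocation performed by
# update_scaling() step 3, isolated from the controller. Free CPU and memory are treated as
# one shared pool; each round the service with the fewest allocated instances that still
# wants more, and still fits, is granted one more instance. _SimpleProfile and _allocate are
# hypothetical stand-ins for ServiceProfile and the in-method loop.
from dataclasses import dataclass


@dataclass
class _SimpleProfile:
    name: str
    desired_instances: int
    cpu: float  # cost of one additional instance, in cores
    ram: float  # cost of one additional instance, in MB


def _allocate(profiles: list, targets: dict, free_cpu: float, free_memory: float) -> dict:
    def trim(prof):
        # Drop anything already at its desired count or too expensive for the remaining pool.
        prof = [p for p in prof if p.desired_instances > targets[p.name]]
        return [p for p in prof if p.cpu <= free_cpu and p.ram <= free_memory]

    remaining = trim(list(profiles))
    while remaining:
        remaining.sort(key=lambda p: targets[p.name])  # least-running service first
        winner = remaining[0]
        free_cpu -= winner.cpu
        free_memory -= winner.ram
        targets[winner.name] += 1
        remaining = trim(remaining)
    return targets


# Example: two services competing for 4 cores / 4096 MB, each instance costing 1 core / 1024 MB.
# _allocate([_SimpleProfile('a', 3, 1, 1024), _SimpleProfile('b', 2, 1, 1024)],
#           {'a': 0, 'b': 0}, free_cpu=4.0, free_memory=4096.0)
# -> {'a': 2, 'b': 2}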