def test_pattern_event(redis_connection: Redis[Any]):
    calls: list[dict[str, Any]] = []

    def _track_call(data: dict[str, Any]):
        calls.append(data)

    watcher = EventWatcher(redis_connection)
    try:
        watcher.register('changes.*', _track_call)
        watcher.start()
        sender = EventSender('changes.', redis_connection)

        start = time.time()
        while len(calls) < 5:
            sender.send(uuid.uuid4().hex, {'payload': 100})
            if time.time() - start > 10:
                pytest.fail("Timed out waiting for pattern-matched events")

        assert len(calls) >= 5
        for row in calls:
            assert row == {'payload': 100}
    finally:
        watcher.stop()
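# A companion sketch (hypothetical, not part of the original suite) showing the non-pattern
# flow: registering an exact channel name instead of a 'changes.*' glob. It reuses only the
# EventSender/EventWatcher calls exercised by test_pattern_event above.
def test_exact_event_sketch(redis_connection: Redis[Any]):
    calls: list[dict[str, Any]] = []

    watcher = EventWatcher(redis_connection)
    try:
        # Exact channel: only messages sent to 'changes.example' are delivered
        watcher.register('changes.example', calls.append)
        watcher.start()
        sender = EventSender('changes.', redis_connection)

        start = time.time()
        while len(calls) < 1:
            sender.send('example', {'payload': 100})
            if time.time() - start > 10:
                pytest.fail("Timed out waiting for an exact-channel event")

        assert calls[0] == {'payload': 100}
    finally:
        watcher.stop()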
class ServiceUpdater(ThreadedCoreBase):
    def __init__(self, logger: logging.Logger = None,
                 shutdown_timeout: float = None, config: Config = None,
                 datastore: AssemblylineDatastore = None,
                 redis: RedisType = None, redis_persist: RedisType = None,
                 default_pattern=".*"):

        self.updater_type = os.environ['SERVICE_PATH'].split('.')[-1].lower()
        self.default_pattern = default_pattern

        if not logger:
            al_log.init_logging(f'updater.{self.updater_type}',
                                log_level=os.environ.get('LOG_LEVEL', "WARNING"))
            logger = logging.getLogger(f'assemblyline.updater.{self.updater_type}')

        super().__init__(f'assemblyline.{SERVICE_NAME}_updater', logger=logger,
                         shutdown_timeout=shutdown_timeout, config=config,
                         datastore=datastore, redis=redis, redis_persist=redis_persist)

        self.update_data_hash = Hash(f'service-updates-{SERVICE_NAME}', self.redis_persist)
        self._update_dir = None
        self._update_tar = None
        self._time_keeper = None
        self._service: Optional[Service] = None

        self.event_sender = EventSender('changes.services',
                                        host=self.config.core.redis.nonpersistent.host,
                                        port=self.config.core.redis.nonpersistent.port)

        self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register(f'changes.services.{SERVICE_NAME}',
                                             self._handle_service_change_event)

        self.signature_change_watcher = EventWatcher(self.redis, deserializer=SignatureChange.deserialize)
        self.signature_change_watcher.register(f'changes.signatures.{SERVICE_NAME.lower()}',
                                               self._handle_signature_change_event)

        # An event flag that gets set when an update should be run for
        # reasons other than it being the regular interval (e.g. a change in signatures)
        self.source_update_flag = threading.Event()
        self.local_update_flag = threading.Event()
        self.local_update_start = threading.Event()

        # Load threads
        self._internal_server = None
        self.expected_threads = {
            'Sync Service Settings': self._sync_settings,
            'Outward HTTP Server': self._run_http,
            'Internal HTTP Server': self._run_internal_http,
            'Run source updates': self._run_source_updates,
            'Run local updates': self._run_local_updates,
        }
        # Only used by updaters with 'generates_signatures: false'
        self.latest_updates_dir = os.path.join(UPDATER_DIR, 'latest_updates')
        if not os.path.exists(self.latest_updates_dir):
            os.makedirs(self.latest_updates_dir)

    def trigger_update(self):
        self.source_update_flag.set()

    def update_directory(self):
        return self._update_dir

    def update_tar(self):
        return self._update_tar

    def get_active_config_hash(self) -> int:
        return self.update_data_hash.get(CONFIG_HASH_KEY) or 0

    def set_active_config_hash(self, config_hash: int):
        self.update_data_hash.set(CONFIG_HASH_KEY, config_hash)

    def get_source_update_time(self) -> float:
        return self.update_data_hash.get(SOURCE_UPDATE_TIME_KEY) or 0

    def set_source_update_time(self, update_time: float):
        self.update_data_hash.set(SOURCE_UPDATE_TIME_KEY, update_time)

    def get_source_extra(self) -> dict[str, Any]:
        return self.update_data_hash.get(SOURCE_EXTRA_KEY) or {}

    def set_source_extra(self, extra_data: dict[str, Any]):
        self.update_data_hash.set(SOURCE_EXTRA_KEY, extra_data)

    def get_local_update_time(self) -> float:
        if self._time_keeper:
            return os.path.getctime(self._time_keeper)
        return 0

    def status(self):
        return {
            'local_update_time': self.get_local_update_time(),
            'download_available': self._update_dir is not None,
            '_directory': self._update_dir,
            '_tar': self._update_tar,
        }

    def stop(self):
        super().stop()
        self.signature_change_watcher.stop()
        self.service_change_watcher.stop()
        self.source_update_flag.set()
        self.local_update_flag.set()
        self.local_update_start.set()
        if self._internal_server:
            self._internal_server.shutdown()

    def try_run(self):
        self.signature_change_watcher.start()
        self.service_change_watcher.start()
        self.maintain_threads(self.expected_threads)

    def _run_internal_http(self):
        """Run the backend (insecure) HTTP server.

        A small in-process server to synchronize info between gunicorn and the
        updater daemon. This HTTP server is not safe for exposing externally,
        but is fine for IPC.
        """
        them = self

        class Handler(BaseHTTPRequestHandler):
            def do_GET(self):
                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(them.status()).encode())

            def log_error(self, format: str, *args: Any):
                them.log.info(format % args)

            def log_message(self, format: str, *args: Any):
                them.log.debug(format % args)

        self._internal_server = ThreadingHTTPServer(('0.0.0.0', 9999), Handler)
        self._internal_server.serve_forever()

    def _run_http(self):
        # Start a server for our http interface in a separate process
        my_env = os.environ.copy()
        proc = subprocess.Popen(["gunicorn", "assemblyline_v4_service.updater.app:app",
                                 "--config=python:assemblyline_v4_service.updater.gunicorn_config"], env=my_env)
        while self.sleep(1):
            if proc.poll() is not None:
                break

        # If we have left the loop and the process is still alive, stop it.
        if proc.poll() is None:
            proc.terminate()
            proc.wait()

    @staticmethod
    def config_hash(service: Service) -> int:
        if service is None:
            return 0
        return hash(json.dumps(service.update_config.as_primitives()))

    def _handle_signature_change_event(self, data: SignatureChange):
        self.local_update_flag.set()

    def _handle_service_change_event(self, data: ServiceChange):
        if data.operation == Operation.Modified:
            self._pull_settings()

    def _sync_settings(self):
        # Download the service object from datastore
        self._service = self.datastore.get_service_with_delta(SERVICE_NAME)
        while self.sleep(SERVICE_PULL_INTERVAL):
            self._pull_settings()

    def _pull_settings(self):
        # Download the service object from datastore
        self._service = self.datastore.get_service_with_delta(SERVICE_NAME)

        # If the update configuration for the service has changed, trigger an update
        if self.config_hash(self._service) != self.get_active_config_hash():
            self.source_update_flag.set()

    def do_local_update(self) -> None:
        old_update_time = self.get_local_update_time()
        if not os.path.exists(UPDATER_DIR):
            os.makedirs(UPDATER_DIR)

        _, time_keeper = tempfile.mkstemp(prefix="time_keeper_", dir=UPDATER_DIR)
        if self._service.update_config.generates_signatures:
            output_directory = tempfile.mkdtemp(prefix="update_dir_", dir=UPDATER_DIR)

            self.log.info("Setup service account.")
            username = self.ensure_service_account()
            self.log.info("Create temporary API key.")
            with temporary_api_key(self.datastore, username) as api_key:
                self.log.info(f"Connecting to Assemblyline API: {UI_SERVER}")
                al_client = get_client(UI_SERVER, apikey=(username, api_key), verify=False)

                # Check if new signatures have been added
                self.log.info("Check for new signatures.")
                if al_client.signature.update_available(
                        since=epoch_to_iso(old_update_time) or '',
                        sig_type=self.updater_type)['update_available']:
                    self.log.info("An update is available for download from the datastore")
                    self.log.debug(f"{self.updater_type} update available since {epoch_to_iso(old_update_time) or ''}")

                    extracted_zip = False
                    attempt = 0

                    # Sometimes a zip file isn't returned; this will affect the service's
                    # use of the signature source. Patience...
                    while not extracted_zip and attempt < 5:
                        temp_zip_file = os.path.join(output_directory, 'temp.zip')
                        al_client.signature.download(
                            output=temp_zip_file,
                            query=f"type:{self.updater_type} AND (status:NOISY OR status:DEPLOYED)")
                        self.log.debug(f"Downloading update to {temp_zip_file}")

                        if os.path.exists(temp_zip_file) and os.path.getsize(temp_zip_file) > 0:
                            self.log.debug(f"File type ({os.path.getsize(temp_zip_file)}B): "
                                           f"{zip_ident(temp_zip_file, 'unknown')}")
                            try:
                                with ZipFile(temp_zip_file, 'r') as zip_f:
                                    zip_f.extractall(output_directory)
                                    extracted_zip = True
                                    self.log.info("Zip extracted.")
                            except BadZipFile:
                                attempt += 1
                                self.log.warning(f"[{attempt}/5] Bad zip. Trying again after 30s...")
                                time.sleep(30)
                            except Exception as e:
                                self.log.error(f'Problem while extracting signatures to disk: {e}')
                                break

                            os.remove(temp_zip_file)

                    if extracted_zip:
                        self.log.info("New ruleset successfully downloaded and ready to use")
                        self.serve_directory(output_directory, time_keeper)
                    else:
                        self.log.error("Signatures aren't saved to disk.")
                        shutil.rmtree(output_directory, ignore_errors=True)
                        if os.path.exists(time_keeper):
                            os.unlink(time_keeper)
                else:
                    self.log.info("No signature updates available.")
                    shutil.rmtree(output_directory, ignore_errors=True)
                    if os.path.exists(time_keeper):
                        os.unlink(time_keeper)
        else:
            output_directory = self.prepare_output_directory()
            self.serve_directory(output_directory, time_keeper)

    def do_source_update(self, service: Service) -> None:
        self.log.info(f"Connecting to Assemblyline API: {UI_SERVER}...")
        run_time = time.time()
        username = self.ensure_service_account()
        with temporary_api_key(self.datastore, username) as api_key:
            with tempfile.TemporaryDirectory() as update_dir:
                al_client = get_client(UI_SERVER, apikey=(username, api_key), verify=False)
                old_update_time = self.get_source_update_time()

                self.log.info("Connected!")

                # Parse updater configuration
                previous_hashes: dict[str, dict[str, str]] = self.get_source_extra()
                sources: dict[str, UpdateSource] = {_s['name']: _s for _s in service.update_config.sources}
                files_sha256: dict[str, dict[str, str]] = {}

                # Go through each source and download file
                for source_name, source_obj in sources.items():
                    source = source_obj.as_primitives()
                    uri: str = source['uri']
                    default_classification = source.get('default_classification', classification.UNRESTRICTED)
                    try:
                        # Pull sources from external locations (method depends on the URL)
                        files = git_clone_repo(source, old_update_time, self.default_pattern, self.log, update_dir) \
                            if uri.endswith('.git') else url_download(source, old_update_time, self.log, update_dir)

                        # Add to collection of sources for caching purposes
                        self.log.info(f"Found new {self.updater_type} rule files to process for {source_name}!")
                        validated_files = list()
                        for file, sha256 in files:
                            files_sha256.setdefault(source_name, {})
                            if previous_hashes.get(source_name, {}).get(file, None) != sha256 and self.is_valid(file):
                                files_sha256[source_name][file] = sha256
                                validated_files.append((file, sha256))

                        # Import into Assemblyline
                        self.import_update(validated_files, al_client, source_name, default_classification)

                    except SkipSource:
                        # This source hasn't changed, no need to re-import into Assemblyline
                        self.log.info(f'No new {self.updater_type} rule files to process for {source_name}')
                        if source_name in previous_hashes:
                            files_sha256[source_name] = previous_hashes[source_name]
                        continue

        self.set_source_update_time(run_time)
        self.set_source_extra(files_sha256)
        self.set_active_config_hash(self.config_hash(service))
        self.local_update_flag.set()
    # Define to determine if a file is a valid signature file
    def is_valid(self, file_path) -> bool:
        return True

    # Define how your source update gets imported into Assemblyline
    def import_update(self, files_sha256: List[Tuple[str, str]], client: Client, source_name: str,
                      default_classification=None):
        raise NotImplementedError()

    # Define how to prepare the output directory before being served; must return the path of the directory to serve.
    def prepare_output_directory(self) -> str:
        output_directory = tempfile.mkdtemp()
        shutil.copytree(self.latest_updates_dir, output_directory, dirs_exist_ok=True)
        return output_directory

    def _run_source_updates(self):
        # Wait until basic data is loaded
        while self._service is None and self.sleep(1):
            pass
        if not self._service:
            return
        self.log.info("Service info loaded")

        try:
            self.log.info("Checking for in cluster update cache")
            self.do_local_update()
            self._service_stage_hash.set(SERVICE_NAME, ServiceStage.Running)
            self.event_sender.send(SERVICE_NAME, {'operation': Operation.Modified, 'name': SERVICE_NAME})
        except Exception:
            self.log.exception('An error occurred loading cached update files. Continuing.')
        self.local_update_start.set()

        # Go into a loop running the update whenever triggered or when it's time to
        while self.running:
            # Stringify and hash the current update configuration
            service = self._service
            update_interval = service.update_config.update_interval_seconds

            # Is it time to update yet?
            if time.time() - self.get_source_update_time() < update_interval and not self.source_update_flag.is_set():
                self.source_update_flag.wait(60)
                continue

            if not self.running:
                return

            # With temp directory
            self.source_update_flag.clear()
            self.log.info('Calling update function...')

            # Run update function
            # noinspection PyBroadException
            try:
                self.do_source_update(service=service)
            except Exception:
                self.log.exception('An error occurred running the update. Will retry...')
                self.source_update_flag.set()
                self.sleep(60)
                continue

    def serve_directory(self, new_directory: str, new_time: str):
        self.log.info("Update finished with new data.")
        new_tar = ''
        try:
            # Tar update directory
            _, new_tar = tempfile.mkstemp(prefix="signatures_", dir=UPDATER_DIR, suffix='.tar.bz2')
            tar_handle = tarfile.open(new_tar, 'w:bz2')
            tar_handle.add(new_directory, '/')
            tar_handle.close()

            # Swap the update directory with the old one
            self._update_dir, new_directory = new_directory, self._update_dir
            self._update_tar, new_tar = new_tar, self._update_tar
            self._time_keeper, new_time = new_time, self._time_keeper
            self.log.info(f"Now serving: {self._update_dir} and {self._update_tar} ({self.get_local_update_time()})")
        finally:
            if new_tar and os.path.exists(new_tar):
                self.log.info(f"Remove old tar file: {new_tar}")
                time.sleep(3)
                os.unlink(new_tar)
            if new_directory and os.path.exists(new_directory):
                self.log.info(f"Remove old directory: {new_directory}")
                shutil.rmtree(new_directory, ignore_errors=True)
            if new_time and os.path.exists(new_time):
                self.log.info(f"Remove old time keeper file: {new_time}")
                os.unlink(new_time)

    def _run_local_updates(self):
        # Wait until basic data is loaded
        while self._service is None and self.sleep(1):
            pass
        if not self._service:
            return
        self.local_update_start.wait()

        # Go into a loop running the update whenever triggered or when it's time to
        while self.running:
            # Is it time to update yet?
            if not self.local_update_flag.is_set():
                self.local_update_flag.wait(60)
                continue

            if not self.running:
                return
            self.local_update_flag.clear()

            # With temp directory
            self.log.info('Updating local files...')

            # Run update function
            # noinspection PyBroadException
            try:
                self.do_local_update()
                if self._service_stage_hash.get(SERVICE_NAME) == ServiceStage.Update:
                    self._service_stage_hash.set(SERVICE_NAME, ServiceStage.Running)
                self.event_sender.send(SERVICE_NAME, {'operation': Operation.Modified, 'name': SERVICE_NAME})
            except Exception:
                self.log.exception('An error occurred finding new local files. Will retry...')
                self.local_update_flag.set()
                self.sleep(60)
                continue

    def ensure_service_account(self):
        """Check that the update service account exists; if it doesn't, create it."""
        uname = 'update_service_account'

        if self.datastore.user.get_if_exists(uname):
            return uname

        user_data = User({
            "agrees_with_tos": "NOW",
            "classification": "RESTRICTED",
            "name": "Update Account",
            "password": get_password_hash(''.join(random.choices(string.ascii_letters, k=20))),
            "uname": uname,
            "type": ["signature_importer"]
        })
        self.datastore.user.save(uname, user_data)
        self.datastore.user_settings.save(uname, UserSettings())
        return uname
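# A minimal, hypothetical sketch of a concrete updater built on the ServiceUpdater base class
# above: override is_valid() to filter downloaded files and import_update() to push signatures
# into Assemblyline. The signature field layout and the add_update_many() call follow common
# Assemblyline client patterns, but the exact payload a given service expects is an assumption.
class SampleUpdateServer(ServiceUpdater):
    def is_valid(self, file_path) -> bool:
        # Assumption: empty files are noise and should be skipped
        return os.path.getsize(file_path) > 0

    def import_update(self, files_sha256, client, source_name, default_classification=None):
        signatures = []
        for file_path, _sha256 in files_sha256:
            with open(file_path, 'r') as fh:
                # Illustrative signature record; a real updater parses its rule format here
                signatures.append({
                    'classification': default_classification,
                    'data': fh.read(),
                    'name': os.path.basename(file_path),
                    'revision': '1',
                    'signature_id': os.path.basename(file_path),
                    'source': source_name,
                    'status': 'DEPLOYED',
                    'type': self.updater_type,
                })
        # Hand the parsed signatures to Assemblyline through the provided client
        client.signature.add_update_many(source_name, self.updater_type, signatures)


if __name__ == '__main__':
    # Typical entry point for an updater container (hedged: mirrors the common run pattern)
    with SampleUpdateServer(default_pattern='.*') as server:
        server.serve_forever()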
class TaskingClient:
    """A helper class to simplify tasking for privileged services and service-server.

    This tool helps take care of interactions between the filestore,
    datastore, dispatcher, and any sources of files to be processed.
    """

    def __init__(self, datastore: AssemblylineDatastore = None, filestore: FileStore = None,
                 config=None, redis=None, redis_persist=None, identify=None):
        self.log = logging.getLogger('assemblyline.tasking_client')
        self.config = config or forge.CachedObject(forge.get_config)
        self.datastore = datastore or forge.get_datastore(self.config)
        self.dispatch_client = DispatchClient(self.datastore, redis=redis, redis_persist=redis_persist)
        self.event_sender = EventSender('changes.services', redis)
        self.filestore = filestore or forge.get_filestore(self.config)
        self.heuristic_handler = HeuristicHandler(self.datastore)
        self.heuristics = {h.heur_id: h for h in self.datastore.list_all_heuristics()}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60 * 30, host=redis)
        self.tag_safelister = forge.CachedObject(forge.get_tag_safelister, kwargs=dict(
            log=self.log, config=config, datastore=self.datastore), refresh=300)
        if identify:
            self.cleanup = False
        else:
            self.cleanup = True
        self.identify = identify or forge.get_identify(config=self.config, datastore=self.datastore, use_cache=True)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.stop()

    def stop(self):
        if self.cleanup:
            self.identify.stop()

    @elasticapm.capture_span(span_type='tasking_client')
    def upload_file(self, file_path, classification, ttl, is_section_image, expected_sha256=None):
        # Identify the file info of the uploaded file
        file_info = self.identify.fileinfo(file_path)

        # Validate SHA256 of the uploaded file
        if expected_sha256 is None or expected_sha256 == file_info['sha256']:
            file_info['archive_ts'] = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            file_info['classification'] = classification
            if ttl:
                file_info['expiry_ts'] = now_as_iso(ttl * 24 * 60 * 60)
            else:
                file_info['expiry_ts'] = None

            # Update the datastore with the uploaded file
            self.datastore.save_or_freshen_file(file_info['sha256'], file_info, file_info['expiry_ts'],
                                                file_info['classification'], is_section_image=is_section_image)

            # Upload file to the filestore (upload already checks if the file exists)
            self.filestore.upload(file_path, file_info['sha256'])
        else:
            raise TaskingClientException("Uploaded file does not match expected file hash. "
                                         f"[{file_info['sha256']} != {expected_sha256}]")

    # Service
    @elasticapm.capture_span(span_type='tasking_client')
    def register_service(self, service_data, log_prefix=""):
        keep_alive = True

        try:
            # Get heuristics list
            heuristics = service_data.pop('heuristics', None)

            # Patch update_channel, registry_type before Service registration object creation
            service_data['update_channel'] = service_data.get(
                'update_channel', self.config.services.preferred_update_channel)
            service_data['docker_config']['registry_type'] = service_data['docker_config'] \
                .get('registry_type', self.config.services.preferred_registry_type)
            service_data['privileged'] = service_data.get('privileged', self.config.services.prefer_service_privileged)
            for dep in service_data.get('dependencies', {}).values():
                dep['container']['registry_type'] = dep.get(
                    'registry_type', self.config.services.preferred_registry_type)

            # Pop unused registration service_data
            for x in ['file_required', 'tool_version']:
                service_data.pop(x, None)

            # Create Service registration object
            service = Service(service_data)

            # Fix service version, we don't need to see the stable label
            service.version = service.version.replace('stable', '')

            # Save service if it doesn't already exist
            if not self.datastore.service.exists(f'{service.name}_{service.version}'):
                self.datastore.service.save(f'{service.name}_{service.version}', service)
                self.datastore.service.commit()
                self.log.info(f"{log_prefix}{service.name} registered")
                keep_alive = False

            # Save service delta if it doesn't already exist
            if not self.datastore.service_delta.exists(service.name):
                self.datastore.service_delta.save(service.name, {'version': service.version})
                self.datastore.service_delta.commit()
                self.log.info(f"{log_prefix}{service.name} version ({service.version}) registered")

            new_heuristics = []
            if heuristics:
                plan = self.datastore.heuristic.get_bulk_plan()
                for index, heuristic in enumerate(heuristics):
                    # Set the heuristic id to its position in the list for logging purposes
                    heuristic_id = f'#{index}'
                    try:
                        # Append service name to heuristic ID
                        heuristic['heur_id'] = f"{service.name.upper()}.{str(heuristic['heur_id'])}"

                        # The attack_id field is now a list; make it a list if we receive otherwise
                        attack_id = heuristic.get('attack_id', None)
                        if isinstance(attack_id, str):
                            heuristic['attack_id'] = [attack_id]

                        heuristic = Heuristic(heuristic)
                        heuristic_id = heuristic.heur_id
                        plan.add_upsert_operation(heuristic_id, heuristic)
                    except Exception as e:
                        msg = f"{service.name} has an invalid heuristic ({heuristic_id}): {str(e)}"
                        self.log.exception(f"{log_prefix}{msg}")
                        raise ValueError(msg)

                for item in self.datastore.heuristic.bulk(plan)['items']:
                    if item['update']['result'] != "noop":
                        new_heuristics.append(item['update']['_id'])
                        self.log.info(f"{log_prefix}{service.name} "
                                      f"heuristic {item['update']['_id']}: {item['update']['result'].upper()}")

                self.datastore.heuristic.commit()

            service_config = self.datastore.get_service_with_delta(service.name, as_obj=False)

            # Notify components watching for service config changes
            self.event_sender.send(service.name, {
                'operation': Operation.Added,
                'name': service.name
            })

        except ValueError as e:
            # Catch errors when building Service or Heuristic model(s)
            raise e

        return dict(keep_alive=keep_alive, new_heuristics=new_heuristics, service_config=service_config or dict())

    # Task
    @elasticapm.capture_span(span_type='tasking_client')
    def get_task(self, client_id, service_name, service_version, service_tool_version,
                 metric_factory, status_expiry=None, timeout=30):
        if status_expiry is None:
            status_expiry = time.time() + timeout

        cache_found = False

        try:
            service_data = self.dispatch_client.service_data[service_name]
        except KeyError:
            raise ServiceMissingException("The service you're asking task for does not exist, try later", 404)

        # Set the service status to Idle since we will be waiting for a task
        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, status_expiry))

        # Getting a new task
        task = self.dispatch_client.request_work(client_id, service_name, service_version, timeout=timeout)

        if not task:
            # We've reached the timeout and no task found in service queue
            return None, False

        # We've got a task to process, consider us busy
        self.status_table.set(client_id, (service_name, ServiceStatus.Running, time.time() + service_data.timeout))
        metric_factory.increment('execute')

        result_key = Result.help_build_key(sha256=task.fileinfo.sha256,
                                           service_name=service_name,
                                           service_version=service_version,
                                           service_tool_version=service_tool_version,
                                           is_empty=False,
                                           task=task)

        # If we are allowed, try to see if the result has been cached
        if not task.ignore_cache and not service_data.disable_cache:
            # Checking for previous results for this key
            result = self.datastore.result.get_if_exists(result_key)
            if result:
                metric_factory.increment('cache_hit')
                if result.result.score:
                    metric_factory.increment('scored')
                else:
                    metric_factory.increment('not_scored')

                result.archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
                if task.ttl:
                    result.expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)

                self.dispatch_client.service_finished(task.sid, result_key, result)
                cache_found = True

            if not cache_found:
                # Checking for previous empty results for this key
                result = self.datastore.emptyresult.get_if_exists(f"{result_key}.e")
                if result:
                    metric_factory.increment('cache_hit')
                    metric_factory.increment('not_scored')
                    result = self.datastore.create_empty_result_from_key(result_key)
                    self.dispatch_client.service_finished(task.sid, f"{result_key}.e", result)
                    cache_found = True

            if not cache_found:
                metric_factory.increment('cache_miss')
        else:
            metric_factory.increment('cache_skipped')

        if not cache_found:
            # No luck with the cache, lets dispatch the task to a client
            return task.as_primitives(), False

        return None, True

    @elasticapm.capture_span(span_type='tasking_client')
    def task_finished(self, service_task, client_id, service_name, metric_factory):
        exec_time = service_task.get('exec_time')

        try:
            task = ServiceTask(service_task['task'])

            if 'result' in service_task:
                # Task created a result
                missing_files = self._handle_task_result(exec_time, task, service_task['result'],
                                                         client_id, service_name, service_task['freshen'],
                                                         metric_factory)
                if missing_files:
                    return dict(success=False, missing_files=missing_files)
                return dict(success=True)

            elif 'error' in service_task:
                # Task created an error
                error = service_task['error']
                self._handle_task_error(exec_time, task, error, client_id, service_name, metric_factory)
                return dict(success=True)
            else:
                return None

        except ValueError as e:
            # Catch errors when building Task or Result model
            raise e

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_result(self, exec_time: int, task: ServiceTask, result: Dict[str, Any],
                            client_id, service_name, freshen: bool, metric_factory):

        def freshen_file(file_info_list, item):
            file_info = file_info_list.get(item['sha256'], None)
            if file_info is None or not self.filestore.exists(item['sha256']):
                return True
            else:
                file_info['archive_ts'] = archive_ts
                file_info['expiry_ts'] = expiry_ts
                file_info['classification'] = item['classification']
                self.datastore.save_or_freshen_file(item['sha256'], file_info,
                                                    file_info['expiry_ts'], file_info['classification'],
                                                    is_section_image=item.get('is_section_image', False))
            return False

        archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)
        else:
            expiry_ts = None

        # Check if all files are in the filestore
        if freshen:
            missing_files = []
            hashes = list(set([f['sha256']
                               for f in result['response']['extracted'] + result['response']['supplementary']]))
            file_infos = self.datastore.file.multiget(hashes, as_obj=False, error_on_missing=False)

            with elasticapm.capture_span(name="handle_task_result.freshen_files", span_type="tasking_client"):
                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    res = {
                        f['sha256']: executor.submit(freshen_file, file_infos, f)
                        for f in result['response']['extracted'] + result['response']['supplementary']
                    }
                for k, v in res.items():
                    if v.result():
                        missing_files.append(k)

            if missing_files:
                return missing_files

        # Add scores to the heuristics, if any section set a heuristic
        with elasticapm.capture_span(name="handle_task_result.process_heuristics", span_type="tasking_client"):
            total_score = 0
            for section in result['result']['sections']:
                zeroize_on_sig_safe = section.pop('zeroize_on_sig_safe', True)
                section['tags'] = flatten(section['tags'])
                if section.get('heuristic'):
                    heur_id = f"{service_name.upper()}.{str(section['heuristic']['heur_id'])}"
                    section['heuristic']['heur_id'] = heur_id
                    try:
                        section['heuristic'], new_tags = self.heuristic_handler.service_heuristic_to_result_heuristic(
                            section['heuristic'], self.heuristics, zeroize_on_sig_safe)
                        for tag in new_tags:
                            section['tags'].setdefault(tag[0], [])
                            if tag[1] not in section['tags'][tag[0]]:
                                section['tags'][tag[0]].append(tag[1])
                        total_score += section['heuristic']['score']
                    except InvalidHeuristicException:
                        section['heuristic'] = None

        # Update the total score of the result
        result['result']['score'] = total_score

        # Add timestamps for creation, archive and expiry
        result['created'] = now_as_iso()
        result['archive_ts'] = archive_ts
        result['expiry_ts'] = expiry_ts

        # Pop the temporary submission data
        temp_submission_data = result.pop('temp_submission_data', None)
        if temp_submission_data:
            old_submission_data = {row.name: row.value for row in task.temporary_submission_data}
            temp_submission_data = {
                k: v for k, v in temp_submission_data.items()
                if k not in old_submission_data or v != old_submission_data[k]
            }
            big_temp_data = {
                k: len(str(v)) for k, v in temp_submission_data.items()
                if len(str(v)) > self.config.submission.max_temp_data_length
            }
            if big_temp_data:
                big_data_sizes = [f"{k}={v}" for k, v in big_temp_data.items()]
                self.log.warning(
                    f"[{task.sid}] The following temporary submission keys were ignored because they are "
                    "bigger than the maximum data size allowed "
                    f"[{self.config.submission.max_temp_data_length}]: {' | '.join(big_data_sizes)}")
                temp_submission_data = {
                    k: v for k, v in temp_submission_data.items() if k not in big_temp_data
                }

        # Process the tag values
        with elasticapm.capture_span(name="handle_task_result.process_tags", span_type="tasking_client"):
            for section in result['result']['sections']:
                # Perform tag safelisting
                tags, safelisted_tags = self.tag_safelister.get_validated_tag_map(section['tags'])
                section['tags'] = unflatten(tags)
                section['safelisted_tags'] = safelisted_tags

                section['tags'], dropped = construct_safe(Tagging, section.get('tags', {}))
                # Set section score to zero and lower total score if service is set to zeroize score
                # and all tags were safelisted
                if section.pop('zeroize_on_tag_safe', False) and \
                        section.get('heuristic') and \
                        len(tags) == 0 and \
                        len(safelisted_tags) != 0:
                    result['result']['score'] -= section['heuristic']['score']
                    section['heuristic']['score'] = 0

                if dropped:
                    self.log.warning(f"[{task.sid}] Invalid tag data from {service_name}: {dropped}")

        result = Result(result)
        result_key = result.build_key(service_tool_version=result.response.service_tool_version, task=task)
        self.dispatch_client.service_finished(task.sid, result_key, result, temp_submission_data)

        # Metrics
        if result.result.score > 0:
            metric_factory.increment('scored')
        else:
            metric_factory.increment('not_scored')

        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"successfully completed task {f' in {exec_time}ms' if exec_time else ''}")

        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, time.time() + 5))

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_error(self, exec_time: int, task: ServiceTask, error: Dict[str, Any],
                           client_id, service_name, metric_factory) -> None:
        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"failed to complete task {f' in {exec_time}ms' if exec_time else ''}")

        # Add timestamps for creation, archive and expiry
        error['created'] = now_as_iso()
        error['archive_ts'] = now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            error['expiry_ts'] = now_as_iso(task.ttl * 24 * 60 * 60)

        error = Error(error)
        error_key = error.build_key(service_tool_version=error.response.service_tool_version, task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error)

        # Metrics
        if error.response.status == 'FAIL_RECOVERABLE':
            metric_factory.increment('fail_recoverable')
        else:
            metric_factory.increment('fail_nonrecoverable')

        self.status_table.set(client_id, (service_name, ServiceStatus.Idle, time.time() + 5))
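# A brief, hypothetical usage note (not part of this module): TaskingClient implements the
# context-manager protocol above, so a privileged service host could poll for work roughly
# as follows; 'worker-1', 'Extract' and the version strings are made-up example values.
#
#     with TaskingClient(redis=redis, redis_persist=redis_persist) as tasking:
#         task, cached = tasking.get_task(client_id='worker-1', service_name='Extract',
#                                         service_version='4.1.0', service_tool_version='',
#                                         metric_factory=metric_factory, timeout=30)
#         if task:
#             ...  # run the service on the task, then report back with tasking.task_finished(...)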
from assemblyline.odm.models.service import SIGNATURE_DELIMITERS
from assemblyline.remote.datatypes import get_client
from assemblyline.remote.datatypes.hash import Hash
from assemblyline.remote.datatypes.lock import Lock
from assemblyline.remote.datatypes.events import EventSender
from assemblyline_ui.api.base import api_login, make_api_response, make_file_response, make_subapi_blueprint
from assemblyline_ui.config import LOGGER, SERVICE_LIST, STORAGE, config, CLASSIFICATION as Classification

SUB_API = 'signature'
signature_api = make_subapi_blueprint(SUB_API, api_version=4)
signature_api._doc = "Perform operations on signatures"

DEFAULT_CACHE_TTL = 24 * 60 * 60  # 1 Day

event_sender = EventSender('changes.signatures',
                           host=config.core.redis.nonpersistent.host,
                           port=config.core.redis.nonpersistent.port)
service_event_sender = EventSender('changes.services',
                                   host=config.core.redis.nonpersistent.host,
                                   port=config.core.redis.nonpersistent.port)


def _reset_service_updates(signature_type):
    service_updates = Hash(
        'service-updates',
        get_client(
            host=config.core.redis.persistent.host,
            port=config.core.redis.persistent.port,
            private=False,
        ))
class ServiceUpdater(ThreadedCoreBase):
    def __init__(self, redis_persist=None, redis=None, logger=None, datastore=None):
        super().__init__('assemblyline.service.updater', logger=logger, datastore=datastore,
                         redis_persist=redis_persist, redis=redis)

        self.container_update: Hash[dict[str, Any]] = Hash('container-update', self.redis_persist)
        self.latest_service_tags: Hash[dict[str, str]] = Hash('service-tags', self.redis_persist)
        self.service_events = EventSender('changes.services', host=self.redis)

        self.incompatible_services = set()

        self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register('changes.services.*', self._handle_service_change_event)

        if 'KUBERNETES_SERVICE_HOST' in os.environ and NAMESPACE:
            extra_labels = {}
            if self.config.core.scaler.additional_labels:
                extra_labels = {k: v for k, v in (_l.split("=") for _l in self.config.core.scaler.additional_labels)}
            self.controller = KubernetesUpdateInterface(prefix='alsvc_', namespace=NAMESPACE,
                                                        priority_class='al-core-priority',
                                                        extra_labels=extra_labels,
                                                        log_level=self.config.logging.log_level)
        else:
            self.controller = DockerUpdateInterface(log_level=self.config.logging.log_level)

    def _handle_service_change_event(self, data: ServiceChange):
        if data.operation == Operation.Incompatible:
            self.incompatible_services.add(data.name)

    def container_updates(self):
        """Go through the list of services marked for update and update their containers."""
        while self.running:
            self.log.info("[CU] Updating all services marked for update...")

            # Update function for services
            def update_service(service_name: str, update_data: dict) -> str:
                self.log.info(f"[CU] Service {service_name} is being updated to version "
                              f"{update_data['latest_tag']}...")

                # Load authentication params
                username = None
                password = None
                auth = update_data['auth'] or {}
                if auth:
                    username = auth.get('username', None)
                    password = auth.get('password', None)

                latest_tag = update_data['latest_tag'].replace('stable', '')
                service_key = f"{service_name}_{latest_tag}"

                try:
                    self.controller.launch(
                        name=service_name,
                        docker_config=DockerConfig(dict(
                            allow_internet_access=True,
                            registry_username=username,
                            registry_password=password,
                            cpu_cores=1,
                            environment=[],
                            image=update_data['image'],
                            ports=[]
                        )),
                        mounts=[],
                        env={
                            "SERVICE_TAG": update_data['latest_tag'],
                            "REGISTER_ONLY": 'true',
                            "PRIVILEGED": 'true',
                        },
                        blocking=True
                    )
                except Exception as e:
                    self.log.error(f"[CU] Service {service_name} has failed to update. "
                                   f"Update procedure cancelled... [{str(e)}]")

                return service_key

            # Start up updates for services in parallel
            update_threads = []
            with ThreadPoolExecutor() as service_updates_exec:
                for service_name, update_data in self.container_update.items().items():
                    update_threads.append(service_updates_exec.submit(update_service, service_name, update_data))

            # Once all threads are completed, check the status of the updates
            for thread in update_threads:
                service_key = thread.result()
                service_name, latest_tag = service_key.split("_")
                if self.datastore.service.get_if_exists(service_key):
                    operations = [(self.datastore.service_delta.UPDATE_SET, 'version', latest_tag)]

                    # Check if a service was previously disabled and re-enable it
                    if service_name in self.incompatible_services:
                        self.incompatible_services.remove(service_name)
                        operations.append((self.datastore.service_delta.UPDATE_SET, 'enabled', True))

                    if self.datastore.service_delta.update(service_name, operations):
                        # Update completed, cleanup
                        self.service_events.send(service_name, {
                            'operation': Operation.Modified,
                            'name': service_name
                        })
                        self.log.info(f"[CU] Service {service_name} update successful!")
                    else:
                        self.log.error(f"[CU] Service {service_name} has failed to update because it cannot set "
                                       f"{latest_tag} as the new version. Update procedure cancelled...")
                else:
                    self.log.error(f"[CU] Service {service_name} has failed to update because resulting "
                                   f"service key ({service_key}) does not exist. Update procedure cancelled...")
                self.container_update.pop(service_name)

            # Clear out any old dead containers
            self.controller.cleanup_stale()

            self.log.info(f"[CU] Done updating services, waiting {UPDATE_CHECK_INTERVAL} seconds for next update...")
            time.sleep(UPDATE_CHECK_INTERVAL)

    def container_versions(self):
        """Go through the list of services and check what the latest tags are for each."""
        while self.running:
            self.log.info("[CV] Checking for new versions of all service containers...")
            existing_services = set(self.container_update.keys()) | set(self.latest_service_tags.keys())
            discovered_services: list[str] = []

            for service in self.datastore.list_all_services(full=True):
                discovered_services.append(service.name)
                image_name, tag_name, auth = get_latest_tag_for_service(service, self.config, self.log,
                                                                        prefix="[CV] ")
                self.latest_service_tags.set(service.name,
                                             {'auth': auth, 'image': image_name, service.update_channel: tag_name})

            # Remove services we have locally or in redis that have been deleted from the database
            for stray_service in existing_services - set(discovered_services):
                self.log.info(f"[CV] Service updates disabled for {stray_service}")
                self._service_stage_hash.pop(stray_service)
                self.container_update.pop(stray_service)
                self.latest_service_tags.pop(stray_service)

            self.log.info("[CV] Done checking for new container versions, "
                          f"waiting {CONTAINER_CHECK_INTERVAL} seconds for next run...")
            time.sleep(CONTAINER_CHECK_INTERVAL)

    def try_run(self):
        # Load and maintain threads
        threads = {
            'Container version check': self.container_versions,
            'Container updates': self.container_updates
        }
        self.maintain_threads(threads)
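# A hypothetical sketch (not part of this module) of how another component could queue a
# service for a container update: container_updates() above consumes entries from the
# persistent 'container-update' Hash and reads their 'image', 'latest_tag' and 'auth' keys.
#
#     container_update = Hash('container-update', redis_persist)
#     container_update.set('Extract', {
#         'image': 'cccs/assemblyline-service-extract',  # example image name
#         'latest_tag': '4.1.0.stable3',                 # example tag
#         'auth': None,
#     })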
from assemblyline.common.str_utils import safe_str
from assemblyline.odm.models.tagging import Tagging
from assemblyline.remote.datatypes.events import EventSender
from assemblyline_ui.config import STORAGE, UI_MESSAGING, config
from assemblyline_ui.api.base import api_login, make_api_response, make_subapi_blueprint

SUB_API = 'system'
system_api = make_subapi_blueprint(SUB_API, api_version=4)
system_api._doc = "Perform system actions"

ADMIN_FILE_TTL = 60 * 60 * 24 * 365 * 100  # Just keep the file for 100 years...

al_re = re.compile(r"^[a-z]+(?:/[a-z0-9\-.+]+)+$")
constants = forge.get_constants()

event_sender = EventSender('system',
                           host=config.core.redis.nonpersistent.host,
                           port=config.core.redis.nonpersistent.port)


@system_api.route("/system_message/", methods=["DELETE"])
@api_login(require_type=['admin'], required_priv=['W'])
def clear_system_message(**_):
    """
    Clear the current system message

    Variables:
    None

    Arguments:
    None
class ScalerServer(ThreadedCoreBase):
    def __init__(self, config=None, datastore=None, redis=None, redis_persist=None):
        super().__init__('assemblyline.scaler', config=config, datastore=datastore,
                         redis=redis, redis_persist=redis_persist)

        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=self.redis_persist)
        self.error_count_lock = threading.Lock()
        self.error_count: dict[str, list[float]] = {}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH, host=self.redis, ttl=30 * 60)
        self.service_event_sender = EventSender('changes.services', host=self.redis)
        self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register('changes.services.*', self._handle_service_change_event)

        core_env: dict[str, str] = {}
        # If we have privileged services, we must be able to pass the necessary environment variables for them to
        # function properly.
        for secret in re.findall(r'\${\w+}', open('/etc/assemblyline/config.yml', 'r').read()) + ['UI_SERVER']:
            env_name = secret.strip("${}")
            core_env[env_name] = os.environ[env_name]

        labels = {
            'app': 'assemblyline',
            'section': 'service',
            'privilege': 'service'
        }
        if self.config.core.scaler.additional_labels:
            labels.update({k: v for k, v in (_l.split("=") for _l in self.config.core.scaler.additional_labels)})

        if KUBERNETES_AL_CONFIG:
            self.log.info(f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}")
            self.controller = KubernetesController(logger=self.log, prefix='alsvc_', labels=labels,
                                                   namespace=NAMESPACE, priority='al-service-priority',
                                                   cpu_reservation=self.config.services.cpu_reservation,
                                                   log_level=self.config.logging.log_level,
                                                   core_env=core_env)
            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_CONFIGMAP:
                self.controller.config_mount('classification-config', config_map=CLASSIFICATION_CONFIGMAP,
                                             key=CLASSIFICATION_CONFIGMAP_KEY,
                                             target_path='/etc/assemblyline/classification.yml')
            if CONFIGURATION_CONFIGMAP:
                self.controller.core_config_mount('assemblyline-config', config_map=CONFIGURATION_CONFIGMAP,
                                                  key=CONFIGURATION_CONFIGMAP_KEY,
                                                  target_path='/etc/assemblyline/config.yml')
        else:
            self.log.info("Loading Docker cluster interface.")
            self.controller = DockerController(logger=self.log, prefix=NAMESPACE, labels=labels,
                                               log_level=self.config.logging.log_level, core_env=core_env)
            self._service_stage_hash.delete()

            if DOCKER_CONFIGURATION_PATH and DOCKER_CONFIGURATION_VOLUME:
                self.controller.core_mounts.append((DOCKER_CONFIGURATION_VOLUME, '/etc/assemblyline/'))

                with open(os.path.join(DOCKER_CONFIGURATION_PATH, 'config.yml'), 'w') as handle:
                    yaml.dump(self.config.as_primitives(), handle)

                with open(os.path.join(DOCKER_CONFIGURATION_PATH, 'classification.yml'), 'w') as handle:
                    yaml.dump(get_classification().original_definition, handle)

            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_HOST_PATH:
                self.controller.global_mounts.append((CLASSIFICATION_HOST_PATH,
                                                      '/etc/assemblyline/classification.yml'))

        # Information about services
        self.profiles: dict[str, ServiceProfile] = {}
        self.profiles_lock = threading.RLock()

        # Prepare a single threaded scheduler
        self.state = collection.Collection(period=self.config.core.metrics.export_interval)
        self.stopping = threading.Event()
        self.main_loop_exit = threading.Event()

        # Load the APM connection if any
        self.apm_client = None
        if self.config.core.metrics.apm_server.server_url:
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(server_url=self.config.core.metrics.apm_server.server_url,
                                                service_name="scaler")

    def log_crashes(self, fn):
        @functools.wraps(fn)
        def with_logs(*args, **kwargs):
            # noinspection PyBroadException
            try:
                fn(*args, **kwargs)
            except ServiceControlError as error:
                self.log.exception(f"Error while managing service: {error.service_name}")
                self.handle_service_error(error.service_name)
            except Exception:
                self.log.exception(f'Crash in scaler: {fn.__name__}')
        return with_logs

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def add_service(self, profile: ServiceProfile):
        # We need to hold the lock the whole time we add the service,
        # we don't want the scaling thread trying to adjust the scale of a
        # deployment we haven't added to the system yet
        with self.profiles_lock:
            profile.desired_instances = max(self.controller.get_target(profile.name), profile.min_instances)
            profile.running_instances = profile.desired_instances
            profile.target_instances = profile.desired_instances
            self.log.debug(f'Starting service {profile.name} with a target of {profile.desired_instances}')
            profile.last_update = time.time()
            self.profiles[profile.name] = profile
            self.controller.add_profile(profile, scale=profile.desired_instances)

    def try_run(self):
        self.service_change_watcher.start()
        self.maintain_threads({
            'Log Container Events': self.log_container_events,
            'Process Timeouts': self.process_timeouts,
            'Service Configuration Sync': self.sync_services,
            'Service Adjuster': self.update_scaling,
            'Import Metrics': self.sync_metrics,
            'Export Metrics': self.export_metrics,
        })

    def stop(self):
        super().stop()
        self.service_change_watcher.stop()
        self.controller.stop()

    def _handle_service_change_event(self, data: ServiceChange):
        if data.operation == Operation.Removed:
            self.log.info(f'Service appears to be deleted, removing {data.name}')
            stage = self.get_service_stage(data.name)
            self.stop_service(data.name, stage)
        elif data.operation == Operation.Incompatible:
            return
        else:
            self._sync_service(self.datastore.get_service_with_delta(data.name))

    def sync_services(self):
        while self.running:
            with apm_span(self.apm_client, 'sync_services'):
                with self.profiles_lock:
                    current_services = set(self.profiles.keys())
                discovered_services: list[str] = []

                # Get all the service data
                for service in self.datastore.list_all_services(full=True):
                    self._sync_service(service)
                    discovered_services.append(service.name)

                # Find any services we have running, that are no longer in the database and remove them
                for stray_service in current_services - set(discovered_services):
                    self.log.info(f'Service appears to be deleted, removing stray {stray_service}')
                    stage = self.get_service_stage(stray_service)
                    self.stop_service(stray_service, stage)

            self.sleep(SERVICE_SYNC_INTERVAL)

    def _sync_service(self, service: Service):
        name = service.name
        stage = self.get_service_stage(service.name)
        default_settings = self.config.core.scaler.service_defaults
        image_variables: defaultdict[str, str] = defaultdict(str)
        image_variables.update(self.config.services.image_variables)

        def prepare_container(docker_config: DockerConfig) -> DockerConfig:
            docker_config.image = Template(docker_config.image).safe_substitute(image_variables)
            set_keys = set(var.name for var in docker_config.environment)
            for var in default_settings.environment:
                if var.name not in set_keys:
                    docker_config.environment.append(var)
            return docker_config

        # noinspection PyBroadException
        try:
            def disable_incompatible_service():
                service.enabled = False
                if self.datastore.service_delta.update(service.name, [
                    (self.datastore.service_delta.UPDATE_SET, 'enabled', False)
                ]):
                    # Raise awareness to other components by sending an event for the service
                    self.service_event_sender.send(service.name, {
                        'operation': Operation.Incompatible,
                        'name': service.name
                    })

            # Check if the service is considered compatible to run on this Assemblyline system
            system_spec = f'{FRAMEWORK_VERSION}.{SYSTEM_VERSION}'
            if not service.version.startswith(system_spec):
                # If the FW and SYS versions don't prefix the service version, we can't guarantee the service
                # is compatible. Disable it and treat it as incompatible due to service version.
                self.log.warning("Disabling service with incompatible version. "
                                 f"[{service.version} != '{system_spec}.X.{service.update_channel}Y'].")
                disable_incompatible_service()
            elif service.update_config and service.update_config.wait_for_update and not service.update_config.sources:
                # All signature sources from a signature-dependent service were removed.
                # Disable it and treat it as incompatible due to service configuration relative to source management.
                self.log.warning("Disabling service with incompatible service configuration. "
                                 "Signature-dependent service has no signature sources.")
                disable_incompatible_service()

            if not service.enabled:
                self.stop_service(service.name, stage)
                return

            # Build the docker config for the dependencies. For now the dependency blob values
            # aren't set for the change key going to kubernetes because everything about
            # the dependency config should be captured in change key that the function generates
            # internally. A change key is set for the service deployment as that includes
            # things like the submission params
            dependency_config: dict[str, Any] = {}
            dependency_blobs: dict[str, str] = {}
            for _n, dependency in service.dependencies.items():
                dependency.container = prepare_container(dependency.container)
                dependency_config[_n] = dependency
                dep_hash = get_id_from_data(dependency, length=16)
                dependency_blobs[_n] = f"dh={dep_hash}v={service.version}p={service.privileged}"

            # Check if the service dependencies have been deployed.
            dependency_keys = []
            updater_ready = stage == ServiceStage.Running
            if service.update_config:
                for _n, dependency in dependency_config.items():
                    key = self.controller.stateful_container_key(service.name, _n, dependency, '')
                    if key:
                        dependency_keys.append(_n + key)
                    else:
                        updater_ready = False

            # If the stage is not set to running or a dependency container is missing, start the setup process
            if not updater_ready:
                self.log.info(f'Preparing environment for {service.name}')
                # Move to the next service stage (do this first because the container we are starting may care)
                if service.update_config and service.update_config.wait_for_update:
                    self._service_stage_hash.set(name, ServiceStage.Update)
                    stage = ServiceStage.Update
                else:
                    self._service_stage_hash.set(name, ServiceStage.Running)
                    stage = ServiceStage.Running

                # Enable this service's dependencies before trying to launch the service containers
                dependency_internet = [(name, dependency.container.allow_internet_access)
                                       for name, dependency in dependency_config.items()]
                self.controller.prepare_network(service.name, service.docker_config.allow_internet_access,
                                                dependency_internet)
                for _n, dependency in dependency_config.items():
                    self.log.info(f'Launching {service.name} dependency {_n}')
                    self.controller.start_stateful_container(
                        service_name=service.name,
                        container_name=_n,
                        spec=dependency,
                        labels={'dependency_for': service.name},
                        change_key=dependency_blobs.get(_n, ''))

            # If the conditions for running are met deploy or update service containers
            if stage == ServiceStage.Running:
                # Build the docker config for the service, we are going to either create it or
                # update it so we need to know what the current configuration is either way
                docker_config = prepare_container(service.docker_config)

                # Compute a blob of service properties not included in the docker config, that
                # should still result in a service being restarted when changed
                cfg_items = get_recursive_sorted_tuples(service.config)
                dep_keys = ''.join(sorted(dependency_keys))
                config_blob = (f"c={cfg_items}sp={service.submission_params}"
                               f"dk={dep_keys}p={service.privileged}d={docker_config}")

                # Add the service to the list of services being scaled
                with self.profiles_lock:
                    if name not in self.profiles:
                        self.log.info(f"Adding "
                                      f"{f'privileged {service.name}' if service.privileged else service.name}"
                                      " to scaling")
                        self.add_service(ServiceProfile(
                            name=name,
                            min_instances=default_settings.min_instances,
                            growth=default_settings.growth,
                            shrink=default_settings.shrink,
                            config_blob=config_blob,
                            dependency_blobs=dependency_blobs,
                            backlog=default_settings.backlog,
                            max_instances=service.licence_count,
                            container_config=docker_config,
                            queue=get_service_queue(name, self.redis),
                            # Give the service an extra 30 seconds to upload results
                            shutdown_seconds=service.timeout + 30,
                            privileged=service.privileged))

                    # Update RAM, CPU, licence requirements for running services
                    else:
                        profile = self.profiles[name]
                        profile.max_instances = service.licence_count
                        profile.privileged = service.privileged

                        for dependency_name, dependency_blob in dependency_blobs.items():
                            if profile.dependency_blobs[dependency_name] != dependency_blob:
                                self.log.info(f"Updating deployment information for {name}/{dependency_name}")
                                profile.dependency_blobs[dependency_name] = dependency_blob
                                self.controller.start_stateful_container(
                                    service_name=service.name,
                                    container_name=dependency_name,
                                    spec=dependency_config[dependency_name],
                                    labels={'dependency_for': service.name},
                                    change_key=dependency_blob)

                        if profile.config_blob != config_blob:
deployment information for {name}") profile.container_config = docker_config profile.config_blob = config_blob self.controller.restart(profile) self.log.info( f"Deployment information for {name} replaced") except Exception: self.log.exception( f"Error applying service settings from: {service.name}") self.handle_service_error(service.name) @elasticapm.capture_span(span_type=APM_SPAN_TYPE) def stop_service(self, name: str, current_stage: ServiceStage): if current_stage != ServiceStage.Off: # Disable this service's dependencies self.controller.stop_containers(labels={'dependency_for': name}) # Mark this service as not running in the shared record self._service_stage_hash.set(name, ServiceStage.Off) # Stop any running disabled services if name in self.profiles or self.controller.get_target(name) > 0: self.log.info(f'Removing {name} from scaling') with self.profiles_lock: self.profiles.pop(name, None) self.controller.set_target(name, 0) def update_scaling(self): """Check if we need to scale any services up or down.""" pool = Pool() while self.sleep(SCALE_INTERVAL): with apm_span(self.apm_client, 'update_scaling'): # Figure out what services are expected to be running and how many with elasticapm.capture_span('read_profiles'): with self.profiles_lock: all_profiles: dict[str, ServiceProfile] = copy.deepcopy( self.profiles) raw_targets = self.controller.get_targets() targets = { _p.name: raw_targets.get(_p.name, 0) for _p in all_profiles.values() } for name, profile in all_profiles.items(): self.log.debug(f'{name}') self.log.debug( f'Instances \t{profile.min_instances} < {profile.desired_instances} | ' f'{targets[name]} < {profile.max_instances}') self.log.debug( f'Pressure \t{profile.shrink_threshold} < ' f'{profile.pressure} < {profile.growth_threshold}') # # 1. Any processes that want to release resources can always be approved first # with pool: for name, profile in all_profiles.items(): if targets[name] > profile.desired_instances: self.log.info( f"{name} wants less resources changing allocation " f"{targets[name]} -> {profile.desired_instances}" ) pool.call(self.controller.set_target, name, profile.desired_instances) targets[name] = profile.desired_instances # # 2. Any processes that aren't reaching their min_instances target must be given # more resources before anyone else is considered. # for name, profile in all_profiles.items(): if targets[name] < profile.min_instances: self.log.info( f"{name} isn't meeting minimum allocation " f"{targets[name]} -> {profile.min_instances}") pool.call(self.controller.set_target, name, profile.min_instances) targets[name] = profile.min_instances # # 3. Try to estimate available resources, and based on some metric grant the # resources to each service that wants them. While this free memory # pool might be spread across many nodes, we are going to treat it like # it is one big one, and let the orchestration layer sort out the details. 
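
                # Illustrative example (numbers are hypothetical, not from the original source): with
                # total_cpu=16 and free_cpu=4 reported by the controller and cpu_overallocation=1.5,
                # used_cpu works out to 12 and the adjusted pool below becomes 16 * 1.5 - 12 = 12 cores
                # that can be handed out to services that still want more instances.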
                # Recalculate the amount of free resources expanding the total quantity by the overallocation
                free_cpu, total_cpu = self.controller.cpu_info()
                used_cpu = total_cpu - free_cpu
                free_cpu = total_cpu * self.config.core.scaler.cpu_overallocation - used_cpu

                free_memory, total_memory = self.controller.memory_info()
                used_memory = total_memory - free_memory
                free_memory = total_memory * self.config.core.scaler.memory_overallocation - used_memory

                #
                def trim(prof: list[ServiceProfile]):
                    prof = [_p for _p in prof if _p.desired_instances > targets[_p.name]]
                    drop = [_p for _p in prof if _p.cpu > free_cpu or _p.ram > free_memory]
                    if drop:
                        summary = {_p.name: (_p.cpu, _p.ram) for _p in drop}
                        self.log.debug(f"Can't make more because not enough resources {summary}")
                    prof = [_p for _p in prof if _p.cpu <= free_cpu and _p.ram <= free_memory]
                    return prof

                remaining_profiles: list[ServiceProfile] = trim(list(all_profiles.values()))
                # The target values up until now should be in sync with the container orchestrator
                # create a copy, so we can track which ones change in the following loop
                old_targets = dict(targets)

                while remaining_profiles:
                    # TODO do we need to add balancing metrics other than 'least running' for this? probably
                    remaining_profiles.sort(key=lambda _p: targets[_p.name])

                    # Add one for the profile at the bottom
                    free_memory -= remaining_profiles[0].container_config.ram_mb
                    free_cpu -= remaining_profiles[0].container_config.cpu_cores
                    targets[remaining_profiles[0].name] += 1

                    # Take out any services that should be happy now
                    remaining_profiles = trim(remaining_profiles)

                # Apply those adjustments we have made back to the controller
                with elasticapm.capture_span('write_targets'):
                    with pool:
                        for name, value in targets.items():
                            if name not in self.profiles:
                                # A service was probably added/removed while we were
                                # in the middle of this function
                                continue
                            self.profiles[name].target_instances = value
                            old = old_targets[name]
                            if value != old:
                                self.log.info(f"Scaling service {name}: {old} -> {value}")
                                pool.call(self.controller.set_target, name, value)

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def handle_service_error(self, service_name: str):
        """Handle an error occurring in the *analysis* service.

        Errors for core systems should simply be logged, and a best effort to continue made.
        For analysis services, ignore the error a few times, then disable the service.
        """
        with self.error_count_lock:
            try:
                self.error_count[service_name].append(time.time())
            except KeyError:
                self.error_count[service_name] = [time.time()]

            self.error_count[service_name] = [
                _t for _t in self.error_count[service_name]
                if _t >= time.time() - ERROR_EXPIRY_TIME
            ]
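
            # Illustrative example (the actual limits come from module-level constants defined elsewhere):
            # if ERROR_EXPIRY_TIME were 600 and MAXIMUM_SERVICE_ERRORS were 5, the check below would only
            # disable a service once it has failed to load five times within the last ten minutes, since
            # older failures have just been pruned from the list above.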
" "The service will be permanently disabled...") if self.datastore.service_delta.update(service_name, [ (self.datastore.service_delta.UPDATE_SET, 'enabled', False) ]): # Raise awareness to other components by sending an event for the service self.service_event_sender.send(service_name, { 'operation': Operation.Modified, 'name': service_name }) del self.error_count[service_name] def sync_metrics(self): """Check if there are any pub-sub messages we need.""" while self.sleep(METRIC_SYNC_INTERVAL): with apm_span(self.apm_client, 'sync_metrics'): # Pull service metrics from redis service_data = self.status_table.items() for host, (service, state, time_limit) in service_data.items(): # If an entry hasn't expired, take it into account if time.time() < time_limit: self.state.update( service=service, host=host, throughput=0, busy_seconds=METRIC_SYNC_INTERVAL if state == ServiceStatus.Running else 0) # If an entry expired a while ago, the host is probably not in use any more if time.time() > time_limit + 600: self.status_table.pop(host) # Download the current targets in the orchestrator while not holding the lock with self.profiles_lock: targets = { name: profile.target_instances for name, profile in self.profiles.items() } # Check the set of services that might be sitting at zero instances, and if it is, we need to # manually check if it is offline export_interval = self.config.core.metrics.export_interval with self.profiles_lock: queues = [ profile.queue for profile in self.profiles.values() if profile.queue ] lengths_list = pq_length(*queues) lengths = {_q: _l for _q, _l in zip(queues, lengths_list)} for profile_name, profile in self.profiles.items(): queue_length = lengths.get(profile.queue, 0) # Pull out statistics from the metrics regularization update = self.state.read(profile_name) if update: delta = time.time() - profile.last_update profile.update(delta=delta, backlog=queue_length, **update) # Check if we expect no messages, if so pull the queue length ourselves # since there is no heartbeat if targets.get( profile_name ) == 0 and profile.desired_instances == 0 and profile.queue: if queue_length > 0: self.log.info( f"Service at zero instances has messages: " f"{profile.name} ({queue_length} in queue)" ) profile.update(delta=export_interval, instances=0, backlog=queue_length, duty_cycle=profile.high_duty_cycle) def _timeout_kill(self, service, container): with apm_span(self.apm_client, 'timeout_kill'): self.controller.stop_container(service, container) self.status_table.pop(container) def process_timeouts(self): with concurrent.futures.ThreadPoolExecutor(10) as pool: futures = [] while self.running: message = self.scaler_timeout_queue.pop(blocking=True, timeout=1) if not message: continue with apm_span(self.apm_client, 'process_timeouts'): # Process new messages self.log.info( f"Killing service container: {message['container']} running: {message['service']}" ) futures.append( pool.submit(self._timeout_kill, message['service'], message['container'])) # Process finished finished = [_f for _f in futures if _f.done()] futures = [_f for _f in futures if _f not in finished] for _f in finished: exception = _f.exception() if exception is not None: self.log.error( f"Exception trying to stop timed out service container: {exception}" ) def export_metrics(self): while self.sleep(self.config.logging.export_interval): with apm_span(self.apm_client, 'export_metrics'): service_metrics = {} with self.profiles_lock: for service_name, profile in self.profiles.items(): service_metrics[service_name] = { 
'running': profile.running_instances, 'target': profile.target_instances, 'minimum': profile.min_instances, 'maximum': profile.instance_limit, 'dynamic_maximum': profile.max_instances, 'queue': profile.queue_length, 'duty_cycle': profile.duty_cycle, 'pressure': profile.pressure } for service_name, metrics in service_metrics.items(): export_metrics_once(service_name, Status, metrics, host=HOSTNAME, counter_type='scaler_status', config=self.config, redis=self.redis) memory, memory_total = self.controller.memory_info() cpu, cpu_total = self.controller.cpu_info() metrics = { 'memory_total': memory_total, 'cpu_total': cpu_total, 'memory_free': memory, 'cpu_free': cpu } export_metrics_once('scaler', Metrics, metrics, host=HOSTNAME, counter_type='scaler', config=self.config, redis=self.redis) def log_container_events(self): """The service status table may have references to containers that have crashed. Try to remove them all.""" while self.sleep(CONTAINER_EVENTS_LOG_INTERVAL): with apm_span(self.apm_client, 'log_container_events'): for message in self.controller.new_events(): self.log.warning("Container Event :: " + message)