def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
    """Wire up configuration, redis connections, datastore handles and dispatcher state.

    Any of ``datastore``, ``redis``, ``redis_persist`` or ``logger`` that is not
    supplied (or is falsy) is replaced by a default built from the loaded config.
    """
    self.config = forge.get_config()

    # Non-persistent redis connection
    if redis:
        self.redis = redis
    else:
        self.redis = get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

    # Persistent redis connection
    if redis_persist:
        self.redis_persist = redis_persist
    else:
        self.redis_persist = get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

    self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)

    if datastore:
        self.ds = datastore
    else:
        self.ds = forge.get_datastore(self.config)

    self.log = logger if logger else logging.getLogger("assemblyline.dispatching.client")

    # Shortcuts to the datastore collections used by this client
    self.results = self.ds.result
    self.errors = self.ds.error
    self.files = self.ds.file

    # Dispatcher state shared through redis
    self.submission_assignments = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
    self.running_tasks = Hash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
    self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))

    # Bookkeeping about dispatcher instances
    self.dispatcher_data = []
    self.dispatcher_data_age = 0.0
    self.dead_dispatchers = []
def __init__(self, working_dir, worker_count=50, spawn_workers=True, use_threading=False, logger=None):
    """Prepare a distributed backup/restore coordinator.

    Args:
        working_dir: directory the backup files are read from / written to.
        worker_count: number of workers to start.
        spawn_workers: stored as ``self.spawn_workers``.
        use_threading: use threads instead of processes for the workers.
        logger: optional logger; ``None`` disables logging.
    """
    self.working_dir = working_dir
    self.datastore = forge.get_datastore(archive_access=True)
    self.logger = logger
    self.plist = []
    self.use_threading = use_threading

    # Redis queue names are namespaced with a random id unique to this instance
    self.instance_id = get_random_id()
    queue_ttl = 1800
    self.worker_queue = NamedQueue(f"r-worker-{self.instance_id}", ttl=queue_ttl)
    self.done_queue = NamedQueue(f"r-done-{self.instance_id}", ttl=queue_ttl)
    self.hash_queue = Hash(f"r-hash-{self.instance_id}")

    self.bucket_error = []
    self.VALID_BUCKETS = sorted(self.datastore.ds.get_models().keys())
    self.worker_count = worker_count
    self.spawn_workers = spawn_workers

    # Progress and result counters
    self.total_count = 0
    self.error_map_count = {}
    self.missing_map_count = {}
    self.map_count = {}
    self.last_time = 0
    self.last_count = 0
    self.error_count = 0
def __init__(self, prefix="counter", host=None, port=None, track_counters=False):
    """Connect to redis and optionally create the tracker hash.

    Args:
        prefix: prefix stored as ``self.prefix``.
        host, port: redis location passed to ``get_client`` (and to the tracker hash).
        track_counters: when true, create a ``c-tracker-<prefix>`` hash as
            ``self.tracker``; otherwise ``self.tracker`` is ``None``.
    """
    self.c = get_client(host, port, False)
    self.prefix = prefix
    self.tracker = Hash("c-tracker-%s" % prefix, host=host, port=port) if track_counters else None
def backup_worker(worker_id, instance_id, working_dir):
    """Pull backup jobs from redis and append the fetched documents to a local file.

    Jobs are popped from ``r-worker-<instance_id>``. Each job names a collection
    (``bucket_name``) and a document ``key``. Each document found is written as one
    JSON line ``[bucket_name, key, document]`` to
    ``<working_dir>/backup.part<worker_id>``. A result record
    (``success``/``missing``/``bucket_name``/``key``) is pushed to
    ``r-done-<instance_id>`` for every job. ``{"stopped": True}`` is pushed once the
    worker exits.

    A ``{"stop": True}`` message switches the worker to stopping mode. It keeps
    draining jobs and exits the first time a pop times out. A stop message received
    while already stopping is pushed back (after a short random sleep) so another
    worker can pick it up.

    When a job has ``follow_keys`` set, documents referenced through ``FOLLOW_KEYS``
    are queued as new jobs. They are de-duplicated across workers through the
    ``r-hash-<instance_id>`` hash.
    """
    datastore = forge.get_datastore(archive_access=True)
    worker_queue = NamedQueue(f"r-worker-{instance_id}", ttl=1800)
    done_queue = NamedQueue(f"r-done-{instance_id}", ttl=1800)
    hash_queue = Hash(f"r-hash-{instance_id}")
    stopping = False

    with open(os.path.join(working_dir, "backup.part%s" % worker_id), "w+") as backup_file:
        while True:
            data = worker_queue.pop(timeout=1)
            if data is None:
                # Queue is empty: only exit once a stop message has been received
                if stopping:
                    break
                continue

            if data.get('stop', False):
                if not stopping:
                    stopping = True
                else:
                    # Already stopping: hand this stop message on to another worker
                    time.sleep(round(random.uniform(0.050, 0.250), 3))
                    worker_queue.push(data)
                continue

            missing = False
            success = True
            try:
                to_write = datastore.get_collection(data['bucket_name']).get(data['key'], as_obj=False)
                if to_write:
                    if data.get('follow_keys', False):
                        for bucket, bucket_key, getter in FOLLOW_KEYS.get(data['bucket_name'], []):
                            for key in getter(to_write.get(bucket_key, None)):
                                hash_key = "%s_%s" % (bucket, key)
                                # Use add()'s result instead of a separate exists() check.
                                # Hash.add returns a truthy value only when it inserted a new
                                # key (see test_hash), so only one worker queues each document.
                                if hash_queue.add(hash_key, "True"):
                                    worker_queue.push({
                                        "bucket_name": bucket,
                                        "key": key,
                                        "follow_keys": True
                                    })

                    backup_file.write(json.dumps((data['bucket_name'], data['key'], to_write)) + "\n")
                else:
                    missing = True
            except Exception:
                # Failures are reported through the done queue instead of killing the worker
                success = False

            done_queue.push({
                "success": success,
                "missing": missing,
                "bucket_name": data['bucket_name'],
                "key": data['key']
            })

    done_queue.push({"stopped": True})
def __init__(self, sender, log, config=None, redis=None):
    """Build the connections and cached settings this formatter works with.

    Args:
        sender: stored as ``self.sender``.
        log: logger stored as ``self.log``.
        config: configuration object; ``forge.get_config()`` when omitted.
        redis: non-persistent redis client; built from the configuration when
            omitted. The persistent client is always built from the configuration.
    """
    self.sender = sender
    self.log = log
    self.config = config if config else forge.get_config()
    self.datastore = forge.get_datastore(self.config)

    redis_cfg = self.config.core.redis
    self.redis = redis if redis else get_client(
        host=redis_cfg.nonpersistent.host,
        port=redis_cfg.nonpersistent.port,
        private=False,
    )
    self.redis_persist = get_client(
        host=redis_cfg.persistent.host,
        port=redis_cfg.persistent.port,
        private=False,
    )

    # Redis-backed status channel, queues and hashes used by this formatter
    self.status_queue = CommsQueue(STATUS_QUEUE, self.redis)
    self.dispatch_active_hash = Hash(DISPATCH_TASK_HASH, self.redis_persist)
    self.dispatcher_submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
    self.ingest_scanning = Hash('m-scanning-table', self.redis_persist)
    self.ingest_unique_queue = PriorityQueue('m-unique', self.redis_persist)
    self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)
    self.ingest_complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)
    self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

    # Priority ranges and sampling thresholds for each priority level
    priority_ranges = forge.get_constants(self.config).PRIORITY_RANGES
    self.c_rng = priority_ranges['critical']
    self.h_rng = priority_ranges['high']
    self.m_rng = priority_ranges['medium']
    self.l_rng = priority_ranges['low']

    sampling_at = self.config.core.ingester.sampling_at
    self.c_s_at = sampling_at['critical']
    self.h_s_at = sampling_at['high']
    self.m_s_at = sampling_at['medium']
    self.l_s_at = sampling_at['low']

    # One counter per expiry metric, starting at zero
    self.to_expire = dict.fromkeys(metrics.EXPIRY_METRICS, 0)

    # Batch deletion appends a '/DAY' rounding suffix to the expiry cut-off
    expiry_cfg = self.config.core.expiry
    ds = self.datastore.ds
    day_rounding = "/DAY" if expiry_cfg.batch_delete else ""
    self.delete_query = f"expiry_ts:[* TO {ds.now}-{expiry_cfg.delay}{ds.hour}{day_rounding}]"

    self.scheduler = BackgroundScheduler(daemon=True)
    self.scheduler.add_job(self._reload_expiry_queues, 'interval',
                           seconds=self.config.core.metrics.export_interval * 4)
    self.scheduler.start()
def __init__(self, logger: logging.Logger = None,
             shutdown_timeout: float = None, config: Config = None,
             datastore: AssemblylineDatastore = None,
             redis: RedisType = None, redis_persist: RedisType = None,
             default_pattern=".*"):
    """Initialise a service updater.

    The updater type is the lower-cased last dotted component of the
    ``SERVICE_PATH`` environment variable, which must be set. When no logger is
    given, logging is initialised for ``updater.<type>`` at level ``LOG_LEVEL``
    (default "WARNING").

    Args:
        logger: logger to use; created as described above when omitted.
        shutdown_timeout, config, datastore, redis, redis_persist: forwarded to
            the base class constructor.
        default_pattern: stored as ``self.default_pattern``.
    """
    self.updater_type = os.environ['SERVICE_PATH'].split('.')[-1].lower()
    self.default_pattern = default_pattern

    if not logger:
        al_log.init_logging(f'updater.{self.updater_type}', log_level=os.environ.get('LOG_LEVEL', "WARNING"))
        logger = logging.getLogger(f'assemblyline.updater.{self.updater_type}')

    super().__init__(f'assemblyline.{SERVICE_NAME}_updater', logger=logger, shutdown_timeout=shutdown_timeout,
                     config=config, datastore=datastore, redis=redis,
                     redis_persist=redis_persist)

    # Per-service hash in persistent redis holding update data
    self.update_data_hash = Hash(f'service-updates-{SERVICE_NAME}', self.redis_persist)
    self._update_dir = None
    self._update_tar = None
    self._time_keeper = None
    self._service: Optional[Service] = None
    self.event_sender = EventSender('changes.services',
                                    host=self.config.core.redis.nonpersistent.host,
                                    port=self.config.core.redis.nonpersistent.port)

    self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
    self.service_change_watcher.register(f'changes.services.{SERVICE_NAME}',
                                         self._handle_service_change_event)

    self.signature_change_watcher = EventWatcher(self.redis, deserializer=SignatureChange.deserialize)
    self.signature_change_watcher.register(f'changes.signatures.{SERVICE_NAME.lower()}',
                                           self._handle_signature_change_event)

    # An event flag that gets set when an update should be run for
    # reasons other than it being the regular interval (eg, change in signatures)
    self.source_update_flag = threading.Event()
    self.local_update_flag = threading.Event()
    self.local_update_start = threading.Event()

    # Load threads
    self._internal_server = None
    self.expected_threads = {
        'Sync Service Settings': self._sync_settings,
        'Outward HTTP Server': self._run_http,
        'Internal HTTP Server': self._run_internal_http,
        'Run source updates': self._run_source_updates,
        'Run local updates': self._run_local_updates,
    }
    # Only used by updater with 'generates_signatures: false'
    self.latest_updates_dir = os.path.join(UPDATER_DIR, 'latest_updates')
    # exist_ok closes the gap between an exists() check and creation
    os.makedirs(self.latest_updates_dir, exist_ok=True)
def __init__(self, working_dir: str, worker_count: int = 50, spawn_workers: bool = True,
             use_threading: bool = False, logger: logging.Logger = None):
    """Prepare a distributed backup/restore coordinator.

    Args:
        working_dir: directory backup files are read from / written to.
        worker_count: number of workers to start.
        spawn_workers: stored as ``self.spawn_workers``.
        use_threading: use threads instead of processes for the workers.
        logger: optional logger; ``None`` disables logging.
    """
    self.working_dir = working_dir
    self.datastore = forge.get_datastore(archive_access=True)
    self.logger = logger
    self.plist: list[Process] = []
    self.use_threading = use_threading

    # Redis names are suffixed with a random id unique to this instance
    self.instance_id = get_random_id()
    suffix = self.instance_id
    self.worker_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-worker-{suffix}", ttl=1800)
    self.done_queue: NamedQueue[dict[str, Any]] = NamedQueue(f"r-done-{suffix}", ttl=1800)
    self.hash_queue: Hash[str] = Hash(f"r-hash-{suffix}")

    self.bucket_error: list[str] = []
    self.valid_buckets: list[str] = sorted(self.datastore.ds.get_models().keys())
    self.worker_count = worker_count
    self.spawn_workers = spawn_workers

    # Progress and result counters
    self.total_count = 0
    self.error_map_count: dict[str, int] = {}
    self.missing_map_count: dict[str, int] = {}
    self.map_count: dict[str, int] = {}
    self.last_time: float = 0
    self.last_count = 0
    self.error_count = 0
def _cleanup_submission(self, task: SubmissionTask, file_list: List[str]):
    """Release everything tied to a submission that is no longer processing.

    Shared by the cancel and finish paths. In order, it:
    - deletes the temporary per-file data hashes,
    - removes the submission from the submitter's quota hash (quota items only),
    - pushes the submission to its completion queue (if one was requested),
    - sends a STOP message to every watcher, then clears the watcher list,
    - clears the timeout, drops the active entry, and increments
      ``submissions_completed``.
    """
    submission = task.submission
    sid = submission.sid
    params = submission.params

    # Temporary data gathered for each file during processing
    for file_hash in file_list:
        temp_name = get_temporary_submission_data_name(sid, file_hash=file_hash)
        ExpiringHash(temp_name, host=self.redis).delete()

    # Give the submitter's quota slot back
    if params.quota_item and params.submitter:
        submitter = params.submitter
        self.log.info(f"[{sid}] Submission no longer counts toward {submitter.upper()} quota")
        Hash('submissions-' + submitter, self.redis_persist).pop(sid)

    completed_queue = task.completed_queue
    if completed_queue:
        self.volatile_named_queue(completed_queue).push(submission.as_primitives())

    # Tell each watcher that this submission is done
    watchers = ExpiringSet(make_watcher_list_name(sid), host=self.redis)
    for watcher_queue_name in watchers.members():
        NamedQueue(watcher_queue_name).push(WatchQueueMessage({'status': 'STOP'}).as_primitives())

    watchers.delete()
    self.timeout_watcher.clear(sid)
    self.active_submissions.pop(sid)

    # Cancelled and finished submissions are both counted as completed
    self.counter.increment('submissions_completed')
def test_hash(redis_connection):
    # Nothing to exercise without a live redis connection
    if not redis_connection:
        return

    from assemblyline.remote.datatypes.hash import Hash

    with Hash('test-hashmap') as h:
        # Single key lifecycle: add, read, overwrite, pop
        assert h.add("key", "value") == 1
        assert h.exists("key") == 1
        assert h.get("key") == "value"
        assert h.set("key", "new-value") == 0
        assert h.keys() == ["key"]
        assert h.length() == 1
        assert h.items() == {"key": "new-value"}
        assert h.pop("key") == "new-value"
        assert h.length() == 0

        # limited_add stops accepting new keys once the table holds `limit` entries
        assert h.limited_add("a", 1, 2) == 1
        assert h.limited_add("a", 1, 2) == 0
        assert h.length() == 1
        assert h.limited_add("b", 10, 2) == 1
        assert h.length() == 2
        assert h.limited_add("c", 1, 2) is None
        assert h.length() == 2
        assert h.pop("a")

        # increment works on integer values, including negative steps
        assert h.increment("a") == 1
        assert h.increment("a") == 2
        assert h.increment("a", 10) == 12
        assert h.increment("a", -22) == -10
class Counters(object):
    """Integer counters stored as plain redis keys named ``<prefix>-<name>``.

    With ``track_counters`` enabled, ``inc`` also records ``now_as_iso()`` under
    ``track_id`` (or the counter name) in the ``c-tracker-<prefix>`` hash, and
    ``dec`` removes that entry. Used as a context manager, every counter under
    the prefix is deleted on exit.
    """

    def __init__(self, prefix="counter", host=None, port=None, track_counters=False):
        self.c = get_client(host, port, False)
        self.prefix = prefix
        if track_counters:
            self.tracker = Hash("c-tracker-%s" % prefix, host=host, port=port)
        else:
            self.tracker = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.delete()

    def inc(self, name, value=1, track_id=None):
        """Increment counter ``name`` by ``value``; returns the result of the redis INCR call."""
        if self.tracker:
            self.tracker.add(track_id or name, now_as_iso())
        return retry_call(self.c.incr, "%s-%s" % (self.prefix, name), value)

    def dec(self, name, value=1, track_id=None):
        """Decrement counter ``name`` by ``value``; returns the result of the redis DECR call."""
        if self.tracker:
            self.tracker.pop(str(track_id or name))
        return retry_call(self.c.decr, "%s-%s" % (self.prefix, name), value)

    def get_queues_sizes(self):
        """Return ``{key: int_value}`` for every counter key under this prefix.

        Keys are full redis key names (``<prefix>-<name>``) decoded as UTF-8.
        """
        sizes = {}
        for queue in retry_call(self.c.keys, "%s-*" % self.prefix):
            raw = retry_call(self.c.get, queue)
            if raw is None:
                # KEYS and GET are separate round trips: a counter deleted in between
                # (e.g. by delete()/reset elsewhere) must not crash on int(None).
                continue
            sizes[queue.decode('utf-8')] = int(raw)
        return sizes

    def get_queues(self):
        """Return the decoded names of all counter keys under this prefix."""
        return [k.decode('utf-8') for k in retry_call(self.c.keys, "%s-*" % self.prefix)]

    def ready(self):
        """Return True if redis answers a PING, False on ConnectionError."""
        try:
            self.c.ping()
        except ConnectionError:
            return False

        return True

    def reset_queues(self):
        """Clear the tracker (if any) and set every counter under the prefix to 0."""
        if self.tracker:
            self.tracker.delete()

        for queue in retry_call(self.c.keys, "%s-*" % self.prefix):
            retry_call(self.c.set, queue, "0")

    def delete(self):
        """Delete the tracker (if any) and every counter key under the prefix."""
        if self.tracker:
            self.tracker.delete()

        for queue in retry_call(self.c.keys, "%s-*" % self.prefix):
            retry_call(self.c.delete, queue)
def check_submission_quota(user, num=1) -> Optional[str]:
    """Check whether ``user`` may start ``num`` more submissions.

    The current usage is the length of the ``submissions-<uname>`` hash, opened
    with the ``persistent`` connection settings. The limit is the user's
    ``submission_quota`` field, or 5 when absent.

    Returns:
        None when the user stays within quota, otherwise an error message.
    """
    username = user['uname']
    limit = user.get('submission_quota', 5)
    active_submissions = Hash('submissions-' + username, **persistent)
    projected = num + active_submissions.length()

    if projected > limit:
        LOGGER.info("User %s exceeded their submission quota. [%s/%s]", username, projected, limit)
        return "You've exceeded your maximum submission quota of %s " % limit

    return None
def get_statistics_cache(config=None, redis=None):
    """Return the ``cached_statistics`` redis hash.

    Uses ``redis`` when given. Otherwise it connects to the persistent redis
    described by ``config``, loading the configuration when ``config`` is also
    missing.
    """
    from assemblyline.remote.datatypes import get_client
    from assemblyline.remote.datatypes.hash import Hash

    if redis:
        return Hash("cached_statistics", redis)

    config = config or get_config()
    persistent_cfg = config.core.redis.persistent
    return Hash("cached_statistics", get_client(persistent_cfg.host, persistent_cfg.port, False))
def __init__(self, redis_persist=None, redis=None, logger=None, datastore=None):
    """Set up redis hashes, service change handling and the container controller.

    A Kubernetes controller is used when ``KUBERNETES_SERVICE_HOST`` is in the
    environment and ``NAMESPACE`` is set; otherwise a Docker controller is used.
    ``core.scaler.additional_labels`` entries have the form ``key=value`` and are
    passed to Kubernetes as extra labels.
    """
    super().__init__('assemblyline.service.updater', logger=logger, datastore=datastore,
                     redis_persist=redis_persist, redis=redis)

    self.container_update: Hash[dict[str, Any]] = Hash('container-update', self.redis_persist)
    self.latest_service_tags: Hash[dict[str, str]] = Hash('service-tags', self.redis_persist)
    self.service_events = EventSender('changes.services', host=self.redis)
    self.incompatible_services = set()

    self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
    self.service_change_watcher.register('changes.services.*', self._handle_service_change_event)

    in_kubernetes = 'KUBERNETES_SERVICE_HOST' in os.environ and NAMESPACE
    if not in_kubernetes:
        self.controller = DockerUpdateInterface(log_level=self.config.logging.log_level)
        return

    extra_labels = {}
    if self.config.core.scaler.additional_labels:
        for label in self.config.core.scaler.additional_labels:
            label_key, label_value = label.split("=")
            extra_labels[label_key] = label_value

    self.controller = KubernetesUpdateInterface(prefix='alsvc_', namespace=NAMESPACE,
                                                priority_class='al-core-priority',
                                                extra_labels=extra_labels,
                                                log_level=self.config.logging.log_level)
def remove_service(servicename, **_):
    """
    Remove a service configuration

    Variables:
    servicename       => Name of the service to remove

    Arguments:
    None

    Data Block:
    None

    Result example:
    {"success": true}  # Has the deletion succeeded
    """
    if not STORAGE.service_delta.get(servicename):
        return make_api_response({"success": False}, err=f"Service {servicename} does not exist", status_code=404)

    # Both deletions are always attempted; either failing marks the request as failed.
    # NOTE(review): `id:{servicename}*` is a prefix wildcard and may also match other
    # services whose name starts with this one — confirm that is intended.
    delta_deleted = STORAGE.service_delta.delete(servicename)
    versions_deleted = STORAGE.service.delete_by_query(f"id:{servicename}*")
    success = bool(delta_deleted) and bool(versions_deleted)

    # Heuristics and signatures are removed too, without affecting `success`
    STORAGE.heuristic.delete_by_query(f"id:{servicename.upper()}*")
    STORAGE.signature.delete_by_query(f"type:{servicename.lower()}*")

    # Notify components watching for service config changes
    event_sender.send(servicename, {
        'operation': Operation.Removed,
        'name': servicename
    })

    # Clear potentially unused keys from Redis related to service
    persistent_redis = get_client(
        host=config.core.redis.persistent.host,
        port=config.core.redis.persistent.port,
        private=False,
    )
    Hash(f'service-updates-{servicename}', persistent_redis).delete()

    return make_api_response({"success": success})
def test_hash(redis_connection):
    # Nothing to exercise without a live redis connection
    if not redis_connection:
        return

    from assemblyline.remote.datatypes.hash import Hash

    with Hash('test-hashmap') as h:
        # Single key lifecycle: add, read, overwrite, pop
        assert h.add("key", "value") == 1
        assert h.exists("key") == 1
        assert h.get("key") == "value"
        assert h.set("key", "new-value") == 0
        assert h.keys() == ["key"]
        assert h.length() == 1
        assert h.items() == {"key": "new-value"}
        assert h.pop("key") == "new-value"
        assert h.length() == 0

        # conditional_remove only deletes when the stored value matches
        assert h.add("key", "value") == 1
        assert h.conditional_remove("key", "value1") is False
        assert h.conditional_remove("key", "value") is True
        assert h.length() == 0

        # limited_add stops accepting new keys once the table holds `limit` entries
        assert h.limited_add("a", 1, 2) == 1
        assert h.limited_add("a", 1, 2) == 0
        assert h.length() == 1
        assert h.limited_add("b", 10, 2) == 1
        assert h.length() == 2
        assert h.limited_add("c", 1, 2) is None
        assert h.length() == 2
        assert h.pop("a")

        # increment works on integer values, including negative steps
        assert h.increment("a") == 1
        assert h.increment("a") == 2
        assert h.increment("a", 10) == 12
        assert h.increment("a", -22) == -10
        h.delete()

        # Bulk-load every 5-letter word over 'abcde' and read it back by iteration
        expected = {word: word * 2 for word in map(''.join, itertools.product('abcde', repeat=5))}
        h.multi_set(expected)
        observed = {key: value for key, value in h}
        assert expected == observed
def _reset_service_updates(signature_type):
    """Make the service whose name matches ``signature_type`` (ignoring case) update soon.

    The first entry of the persistent ``service-updates`` hash with a matching key
    gets ``next_update = now_as_iso(120)`` and
    ``previous_update = now_as_iso(-10**10)``. Nothing happens if no entry matches
    or if the entry disappears before it can be read.
    """
    service_updates = Hash(
        'service-updates',
        get_client(
            host=config.core.redis.persistent.host,
            port=config.core.redis.persistent.port,
            private=False,
        ))

    wanted = signature_type.lower()
    # Only the names are needed to find the match, so fetch keys rather than all entries
    for svc in service_updates.keys():
        if svc.lower() != wanted:
            continue

        update_data = service_updates.get(svc)
        # The entry may have been removed since keys() was read
        if update_data is not None:
            update_data['next_update'] = now_as_iso(120)
            update_data['previous_update'] = now_as_iso(-10**10)
            service_updates.set(svc, update_data)
        break
def __init__(self, redis_persist=None, redis=None, logger=None, datastore=None):
    """Prepare the file-based service updater.

    Recreates ``<FILE_UPDATE_DIRECTORY>/.tmp``, opens the redis hashes it uses,
    creates a scheduler, and picks a Kubernetes or Docker controller.

    Raises:
        RuntimeError: if ``FILE_UPDATE_DIRECTORY`` is not set.
    """
    super().__init__('assemblyline.service.updater', logger=logger, datastore=datastore,
                     redis_persist=redis_persist, redis=redis)

    if not FILE_UPDATE_DIRECTORY:
        raise RuntimeError(
            "The updater process must be run within the orchestration environment, "
            "the update volume must be mounted, and the path to the volume must be "
            "set in the environment variable FILE_UPDATE_DIRECTORY. Setting "
            "FILE_UPDATE_DIRECTORY directly may be done for testing.")

    # The directory where we want working temporary directories to be created.
    # Building our temporary directories in the persistent update volume may
    # have some performance downsides, but may help us run into fewer docker FS overlay
    # cleanup issues. Try to flush it out every time we start. This service should
    # be a singleton anyway.
    self.temporary_directory = os.path.join(FILE_UPDATE_DIRECTORY, '.tmp')
    shutil.rmtree(self.temporary_directory, ignore_errors=True)
    # rmtree ignores errors, so the directory may still exist here; don't fail on it
    os.makedirs(self.temporary_directory, exist_ok=True)

    self.container_update = Hash('container-update', self.redis_persist)
    self.services = Hash('service-updates', self.redis_persist)
    self.latest_service_tags = Hash('service-tags', self.redis_persist)
    self.running_updates: Dict[str, Thread] = {}

    # Prepare a single threaded scheduler
    self.scheduler = sched.scheduler()

    # Use Kubernetes when running in a cluster with a configured namespace, Docker otherwise
    if 'KUBERNETES_SERVICE_HOST' in os.environ and NAMESPACE:
        self.controller = KubernetesUpdateInterface(prefix='alsvc_', namespace=NAMESPACE,
                                                    priority_class='al-core-priority')
    else:
        self.controller = DockerUpdateInterface()
class HeartbeatFormatter(object):
    """Build typed heartbeat messages from component metrics and publish them.

    ``send_heartbeat`` selects a message type from the component type. It adds
    the relevant queue lengths read from redis and the datastore, then publishes
    the result on the ``STATUS_QUEUE`` comms queue.
    """

    def __init__(self, sender, log, config=None, redis=None):
        """Open connections and cache settings used when formatting heartbeats.

        Args:
            sender: value placed in every message's ``sender`` field.
            log: logger used for progress and error reporting.
            config: configuration; ``forge.get_config()`` when omitted.
            redis: non-persistent redis client; built from the configuration when
                omitted. The persistent client is always built from configuration.
        """
        self.sender = sender
        self.log = log

        self.config = config or forge.get_config()
        self.datastore = forge.get_datastore(self.config)

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.redis_persist = get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )
        self.status_queue = CommsQueue(STATUS_QUEUE, self.redis)
        self.dispatch_active_hash = Hash(DISPATCH_TASK_HASH, self.redis_persist)
        self.dispatcher_submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.ingest_scanning = Hash('m-scanning-table', self.redis_persist)
        self.ingest_unique_queue = PriorityQueue('m-unique', self.redis_persist)
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)
        self.ingest_complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Priority ranges and sampling thresholds used for the ingester heartbeat
        constants = forge.get_constants(self.config)
        self.c_rng = constants.PRIORITY_RANGES['critical']
        self.h_rng = constants.PRIORITY_RANGES['high']
        self.m_rng = constants.PRIORITY_RANGES['medium']
        self.l_rng = constants.PRIORITY_RANGES['low']
        self.c_s_at = self.config.core.ingester.sampling_at['critical']
        self.h_s_at = self.config.core.ingester.sampling_at['high']
        self.m_s_at = self.config.core.ingester.sampling_at['medium']
        self.l_s_at = self.config.core.ingester.sampling_at['low']

        # Per-collection count of expired documents, refreshed by _reload_expiry_queues
        self.to_expire = {k: 0 for k in metrics.EXPIRY_METRICS}
        if self.config.core.expiry.batch_delete:
            self.delete_query = f"expiry_ts:[* TO {self.datastore.ds.now}-{self.config.core.expiry.delay}" \
                f"{self.datastore.ds.hour}/DAY]"
        else:
            self.delete_query = f"expiry_ts:[* TO {self.datastore.ds.now}-{self.config.core.expiry.delay}" \
                f"{self.datastore.ds.hour}]"

        self.scheduler = BackgroundScheduler(daemon=True)
        self.scheduler.add_job(self._reload_expiry_queues, 'interval',
                               seconds=self.config.core.metrics.export_interval * 4)
        self.scheduler.start()

    def _reload_expiry_queues(self):
        """Refresh ``self.to_expire`` with the expired-document count of each collection.

        A SearchException for one collection records 0 for it. Any other error is
        logged, and collections not yet visited keep their previous values.
        """
        try:
            self.log.info("Refreshing expiry queues...")
            for collection_name in metrics.EXPIRY_METRICS:
                try:
                    collection = getattr(self.datastore, collection_name)
                    self.to_expire[collection_name] = collection.search(
                        self.delete_query, rows=0, fl='id', track_total_hits="true")['total']
                except SearchException:
                    self.to_expire[collection_name] = 0
        except Exception:
            self.log.exception("Unknown exception occurred while reloading expiry queues:")

    def send_heartbeat(self, m_type, m_name, m_data, instances):
        """Publish the heartbeat message for one component.

        Args:
            m_type: component type ("dispatcher", "ingester", "alerter", "expiry",
                "archive", "scaler", "scaler_status" or "service").
            m_name: component name; used as the service name for "scaler_status"
                and "service" heartbeats.
            m_data: exported metric values, copied into the message.
            instances: instance information from the caller. It is passed through
                for ingester/alerter/expiry/archive/scaler. Dispatcher and service
                heartbeats compute their own instance counts.

        Errors while building or publishing a message are logged, never raised.
        Unknown component types are logged as a warning and skipped.
        """
        if m_type == "dispatcher":
            try:
                instances = sorted(Dispatcher.all_instances(self.redis_persist))
                inflight = {_i: Dispatcher.instance_assignment_size(self.redis_persist, _i) for _i in instances}
                queues = {_i: Dispatcher.all_queue_lengths(self.redis, _i) for _i in instances}
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "inflight": {
                            "max": self.config.core.dispatcher.max_inflight,
                            "outstanding": self.dispatch_active_hash.length(),
                            "per_instance": [inflight[_i] for _i in instances]
                        },
                        "instances": len(instances),
                        "metrics": m_data,
                        "queues": {
                            "ingest": self.dispatcher_submission_queue.length(),
                            "start": [queues[_i]['start'] for _i in instances],
                            "result": [queues[_i]['result'] for _i in instances],
                            "command": [queues[_i]['command'] for _i in instances]
                        },
                        "component": m_name,
                    }
                }
                self.status_queue.publish(DispatcherMessage(msg).as_primitives())
                self.log.info(f"Sent dispatcher heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception("An exception occurred while generating DispatcherMessage")

        elif m_type == "ingester":
            try:
                c_q_len = self.ingest_unique_queue.count(*self.c_rng)
                h_q_len = self.ingest_unique_queue.count(*self.h_rng)
                m_q_len = self.ingest_unique_queue.count(*self.m_rng)
                l_q_len = self.ingest_unique_queue.count(*self.l_rng)

                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                        "processing": {
                            "inflight": self.ingest_scanning.length()
                        },
                        "processing_chance": {
                            "critical": 1 - drop_chance(c_q_len, self.c_s_at),
                            "high": 1 - drop_chance(h_q_len, self.h_s_at),
                            "low": 1 - drop_chance(l_q_len, self.l_s_at),
                            "medium": 1 - drop_chance(m_q_len, self.m_s_at)
                        },
                        "queues": {
                            "critical": c_q_len,
                            "high": h_q_len,
                            "ingest": self.ingest_queue.length(),
                            "complete": self.ingest_complete_queue.length(),
                            "low": l_q_len,
                            "medium": m_q_len
                        }
                    }
                }
                self.status_queue.publish(IngestMessage(msg).as_primitives())
                self.log.info(f"Sent ingester heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception("An exception occurred while generating IngestMessage")

        elif m_type == "alerter":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                        "queues": {
                            "alert": self.alert_queue.length()
                        }
                    }
                }
                self.status_queue.publish(AlerterMessage(msg).as_primitives())
                self.log.info(f"Sent alerter heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception("An exception occurred while generating AlerterMessage")

        elif m_type == "expiry":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                        "queues": self.to_expire
                    }
                }
                self.status_queue.publish(ExpiryMessage(msg).as_primitives())
                self.log.info(f"Sent expiry heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception("An exception occurred while generating ExpiryMessage")

        elif m_type == "archive":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data
                    }
                }
                self.status_queue.publish(ArchiveMessage(msg).as_primitives())
                self.log.info(f"Sent archive heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception("An exception occurred while generating ArchiveMessage")

        elif m_type == "scaler":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                    }
                }
                self.status_queue.publish(ScalerMessage(msg).as_primitives())
                self.log.info(f"Sent scaler heartbeat: {msg['msg']}")
            except Exception:
                # Name the message type actually being built (was "WatcherMessage")
                self.log.exception("An exception occurred while generating ScalerMessage")

        elif m_type == "scaler_status":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "service_name": m_name,
                        "metrics": m_data,
                    }
                }
                self.status_queue.publish(ScalerStatusMessage(msg).as_primitives())
                self.log.info(f"Sent scaler status heartbeat: {msg['msg']}")
            except Exception:
                # Name the message type actually being built (was "WatcherMessage")
                self.log.exception("An exception occurred while generating ScalerStatusMessage")

        elif m_type == "service":
            try:
                busy, idle = get_working_and_idle(self.redis, m_name)
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": len(busy) + len(idle),
                        "metrics": m_data,
                        "activity": {
                            'busy': len(busy),
                            'idle': len(idle)
                        },
                        "queue": get_service_queue(m_name, self.redis).length(),
                        "service_name": m_name
                    }
                }
                self.status_queue.publish(ServiceMessage(msg).as_primitives())
                self.log.info(f"Sent service heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception("An exception occurred while generating ServiceMessage")

        else:
            self.log.warning(f"Skipping unknown counter: {m_name} [{m_type}] ==> {m_data}")
def get_service_stage_hash(redis) -> Hash[int]:
    """Return the ``service-stage`` hash: service name -> ServiceStage enum value (as int)."""
    stage_hash_name = 'service-stage'
    return Hash(stage_hash_name, redis)
class DistributedBackup(object):
    """Coordinate a multi-worker backup or restore of datastore collections.

    Work items go through redis queues namespaced with a random instance id.
    Workers (threads or processes running ``backup_worker`` / ``restore_worker``)
    report per-item results on the done queue. ``done_thread`` aggregates those
    results into counters and a final summary.
    """

    def __init__(self, working_dir, worker_count=50, spawn_workers=True, use_threading=False, logger=None):
        self.working_dir = working_dir
        self.datastore = forge.get_datastore(archive_access=True)
        self.logger = logger
        self.plist = []
        self.use_threading = use_threading
        self.instance_id = get_random_id()
        self.worker_queue = NamedQueue(f"r-worker-{self.instance_id}", ttl=1800)
        self.done_queue = NamedQueue(f"r-done-{self.instance_id}", ttl=1800)
        self.hash_queue = Hash(f"r-hash-{self.instance_id}")
        self.bucket_error = []
        self.VALID_BUCKETS = sorted(list(self.datastore.ds.get_models().keys()))
        self.worker_count = worker_count
        self.spawn_workers = spawn_workers
        self.total_count = 0
        self.error_map_count = {}
        self.missing_map_count = {}
        self.map_count = {}
        self.last_time = 0
        self.last_count = 0
        self.error_count = 0

    def cleanup(self):
        """Delete this instance's redis queues and terminate worker processes."""
        self.worker_queue.delete()
        self.done_queue.delete()
        self.hash_queue.delete()
        for p in self.plist:
            p.terminate()

    def done_thread(self, title):
        """Aggregate worker results until every worker has reported it stopped.

        Then cleans up and logs a summary.

        Args:
            title: operation name used in the summary (e.g. "Backup", "Restore").
        """
        t0 = time.time()
        self.last_time = t0

        running_threads = self.worker_count

        while running_threads > 0:
            msg = self.done_queue.pop(timeout=1)

            if msg is None:
                continue

            if "stopped" in msg:
                running_threads -= 1
                continue

            bucket_name = msg.get('bucket_name', 'unknown')

            if msg.get('success', False):
                self.total_count += 1

                if msg.get("missing", False):
                    self.missing_map_count[bucket_name] = self.missing_map_count.get(bucket_name, 0) + 1
                else:
                    self.map_count[bucket_name] = self.map_count.get(bucket_name, 0) + 1

                # Periodic progress report, at most every 5 seconds
                new_t = time.time()
                if (new_t - self.last_time) > 5:
                    if self.logger:
                        self.logger.info("%s (%s at %s keys/sec) ==> %s" %
                                         (self.total_count,
                                          new_t - self.last_time,
                                          int((self.total_count - self.last_count) / (new_t - self.last_time)),
                                          self.map_count))
                    self.last_count = self.total_count
                    self.last_time = new_t
            else:
                self.error_count += 1
                self.error_map_count[bucket_name] = self.error_map_count.get(bucket_name, 0) + 1

        # Cleanup
        self.cleanup()

        parts = [
            "\n########################\n",
            "####### SUMMARY #######\n",
            "########################\n",
            "%s items - %s errors - %s secs\n\n" % (self.total_count, self.error_count, time.time() - t0),
        ]

        for k, v in self.map_count.items():
            parts.append("\t%15s: %s\n" % (k.upper(), v))

        if self.missing_map_count:
            parts.append("\n\nMissing data:\n\n")
            for k, v in self.missing_map_count.items():
                parts.append("\t%15s: %s\n" % (k.upper(), v))

        if self.error_map_count:
            parts.append("\n\nErrors:\n\n")
            for k, v in self.error_map_count.items():
                parts.append("\t%15s: %s\n" % (k.upper(), v))

        if self.bucket_error:
            parts.append(f"\nThese buckets failed to {title.lower()} completely: {self.bucket_error}\n")

        if self.logger:
            self.logger.info("".join(parts))

    # noinspection PyBroadException,PyProtectedMember
    def backup(self, bucket_list, follow_keys=False, query=None):
        """Back up every document matching ``query`` in the given buckets.

        Args:
            bucket_list: collection names; all must be in ``VALID_BUCKETS``, or
                nothing is done.
            follow_keys: ask workers to also back up referenced documents.
            query: search query selecting documents; defaults to ``'id:*'``.

        Errors are logged (when a logger is set), never raised. A bucket that fails
        while streaming is recorded in ``bucket_error``; the remaining buckets are
        still processed.
        """
        if query is None:
            query = 'id:*'

        for bucket in bucket_list:
            if bucket not in self.VALID_BUCKETS:
                if self.logger:
                    self.logger.warning("\n%s is not a valid bucket.\n\n"
                                        "The list of valid buckets is the following:\n\n\t%s\n" %
                                        (bucket.upper(), "\n\t".join(self.VALID_BUCKETS)))
                return

        targets = ', '.join(bucket_list)
        try:
            if self.logger:
                self.logger.info("\n-----------------------")
                self.logger.info("----- Data Backup -----")
                self.logger.info("-----------------------")
                self.logger.info(f"    Deep: {follow_keys}")
                self.logger.info(f"    Buckets: {targets}")
                self.logger.info(f"    Workers: {self.worker_count}")
                self.logger.info(f"    Target directory: {self.working_dir}")
                self.logger.info(f"    Filtering query: {query}")

            # Start the workers
            for x in range(self.worker_count):
                if self.use_threading:
                    t = threading.Thread(target=backup_worker, args=(x, self.instance_id, self.working_dir),
                                         daemon=True)
                    t.start()
                else:
                    p = Process(target=backup_worker, args=(x, self.instance_id, self.working_dir))
                    p.start()
                    self.plist.append(p)

            # Start done thread
            dt = threading.Thread(target=self.done_thread, args=('Backup', ), name="Done thread", daemon=True)
            dt.start()

            # Process data buckets
            for bucket_name in bucket_list:
                try:
                    collection = self.datastore.get_collection(bucket_name)
                    for item in collection.stream_search(query, fl="id", item_buffer_size=500, as_obj=False):
                        self.worker_queue.push({
                            "bucket_name": bucket_name,
                            "key": item['id'],
                            "follow_keys": follow_keys
                        })

                except Exception as e:
                    # Do not call cleanup() here. It would delete the shared queues and
                    # terminate worker processes mid-run, leaving done_thread waiting for
                    # stop reports that never arrive. done_thread cleans up once every
                    # worker has stopped.
                    if self.logger:
                        self.logger.exception(e)
                        self.logger.error("Error occurred while processing bucket %s." % bucket_name)
                    self.bucket_error.append(bucket_name)

            for _ in range(self.worker_count):
                self.worker_queue.push({"stop": True})

            dt.join()
        except Exception as e:
            if self.logger:
                self.logger.exception(e)

    def restore(self):
        """Restore data with ``worker_count`` workers and wait for them to finish.

        Errors are logged (when a logger is set), never raised.
        """
        try:
            if self.logger:
                self.logger.info("\n------------------------")
                self.logger.info("----- Data Restore -----")
                self.logger.info("------------------------")
                self.logger.info(f"    Workers: {self.worker_count}")
                self.logger.info(f"    Target directory: {self.working_dir}")

            for x in range(self.worker_count):
                if self.use_threading:
                    t = threading.Thread(target=restore_worker, args=(x, self.instance_id, self.working_dir),
                                         daemon=True)
                    t.start()
                else:
                    p = Process(target=restore_worker, args=(x, self.instance_id, self.working_dir))
                    p.start()
                    self.plist.append(p)

            # Start done thread
            dt = threading.Thread(target=self.done_thread, args=('Restore', ), name="Done thread", daemon=True)
            dt.start()

            # Wait for workers to finish
            dt.join()
        except Exception as e:
            if self.logger:
                self.logger.exception(e)
class ServiceUpdater(ThreadedCoreBase):
    """Background daemon that keeps a service's signature/update files current.

    It runs several threads (see ``expected_threads``): settings sync, an external
    gunicorn HTTP server, an internal IPC HTTP server, a loop that pulls sources
    into Assemblyline (``do_source_update``) and a loop that packages the current
    signatures for serving (``do_local_update``).
    """

    def __init__(self, logger: logging.Logger = None,
                 shutdown_timeout: float = None,
                 config: Config = None,
                 datastore: AssemblylineDatastore = None,
                 redis: RedisType = None, redis_persist: RedisType = None,
                 default_pattern=".*"):

        self.updater_type = os.environ['SERVICE_PATH'].split('.')[-1].lower()
        self.default_pattern = default_pattern

        if not logger:
            al_log.init_logging(f'updater.{self.updater_type}', log_level=os.environ.get('LOG_LEVEL', "WARNING"))
            logger = logging.getLogger(f'assemblyline.updater.{self.updater_type}')

        super().__init__(f'assemblyline.{SERVICE_NAME}_updater', logger=logger, shutdown_timeout=shutdown_timeout,
                         config=config, datastore=datastore, redis=redis,
                         redis_persist=redis_persist)

        self.update_data_hash = Hash(f'service-updates-{SERVICE_NAME}', self.redis_persist)
        self._update_dir = None
        self._update_tar = None
        self._time_keeper = None
        self._service: Optional[Service] = None
        self.event_sender = EventSender('changes.services',
                                        host=self.config.core.redis.nonpersistent.host,
                                        port=self.config.core.redis.nonpersistent.port)

        self.service_change_watcher = EventWatcher(self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register(f'changes.services.{SERVICE_NAME}', self._handle_service_change_event)

        self.signature_change_watcher = EventWatcher(self.redis, deserializer=SignatureChange.deserialize)
        self.signature_change_watcher.register(f'changes.signatures.{SERVICE_NAME.lower()}',
                                               self._handle_signature_change_event)

        # An event flag that gets set when an update should be run for
        # reasons other than it being the regular interval (eg, change in signatures)
        self.source_update_flag = threading.Event()
        self.local_update_flag = threading.Event()
        self.local_update_start = threading.Event()

        # Load threads
        self._internal_server = None
        self.expected_threads = {
            'Sync Service Settings': self._sync_settings,
            'Outward HTTP Server': self._run_http,
            'Internal HTTP Server': self._run_internal_http,
            'Run source updates': self._run_source_updates,
            'Run local updates': self._run_local_updates,
        }
        # Only used by updater with 'generates_signatures: false'
        self.latest_updates_dir = os.path.join(UPDATER_DIR, 'latest_updates')
        os.makedirs(self.latest_updates_dir, exist_ok=True)

    def trigger_update(self):
        """Request a source update on the next loop iteration."""
        self.source_update_flag.set()

    def update_directory(self):
        """Return the directory currently being served (or None)."""
        return self._update_dir

    def update_tar(self):
        """Return the tarball currently being served (or None)."""
        return self._update_tar

    def get_active_config_hash(self) -> int:
        return self.update_data_hash.get(CONFIG_HASH_KEY) or 0

    def set_active_config_hash(self, config_hash: int):
        self.update_data_hash.set(CONFIG_HASH_KEY, config_hash)

    def get_source_update_time(self) -> float:
        return self.update_data_hash.get(SOURCE_UPDATE_TIME_KEY) or 0

    def set_source_update_time(self, update_time: float):
        self.update_data_hash.set(SOURCE_UPDATE_TIME_KEY, update_time)

    def get_source_extra(self) -> dict[str, Any]:
        return self.update_data_hash.get(SOURCE_EXTRA_KEY) or {}

    def set_source_extra(self, extra_data: dict[str, Any]):
        self.update_data_hash.set(SOURCE_EXTRA_KEY, extra_data)

    def get_local_update_time(self) -> float:
        """Return the ctime of the current time-keeper file, or 0 if nothing is served yet."""
        if self._time_keeper:
            return os.path.getctime(self._time_keeper)
        return 0

    def status(self):
        """Snapshot of the serving state, exposed over the internal HTTP server."""
        return {
            'local_update_time': self.get_local_update_time(),
            'download_available': self._update_dir is not None,
            '_directory': self._update_dir,
            '_tar': self._update_tar,
        }

    def stop(self):
        super().stop()
        self.signature_change_watcher.stop()
        self.service_change_watcher.stop()
        # Wake up every thread blocked on an event so it can notice we are stopping
        self.source_update_flag.set()
        self.local_update_flag.set()
        self.local_update_start.set()
        if self._internal_server:
            self._internal_server.shutdown()

    def try_run(self):
        self.signature_change_watcher.start()
        self.service_change_watcher.start()
        self.maintain_threads(self.expected_threads)

    def _run_internal_http(self):
        """Run the backend insecure HTTP server.

        A small in-process server to synchronize info between gunicorn and the updater daemon.
        This HTTP server is not safe for exposing externally, but fine for IPC.
        """
        them = self

        class Handler(BaseHTTPRequestHandler):
            def do_GET(self):
                self.send_response(200)
                self.send_header("Content-type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(them.status()).encode())

            def log_error(self, format: str, *args: Any):
                them.log.info(format % args)

            def log_message(self, format: str, *args: Any):
                them.log.debug(format % args)

        # NOTE(review): binds every interface although the server is documented as IPC-only;
        # confirm whether gunicorn could reach it on localhost instead.
        self._internal_server = ThreadingHTTPServer(('0.0.0.0', 9999), Handler)
        self._internal_server.serve_forever()

    def _run_http(self):
        # Start a server for our http interface in a separate process
        my_env = os.environ.copy()
        proc = subprocess.Popen(["gunicorn", "assemblyline_v4_service.updater.app:app",
                                 "--config=python:assemblyline_v4_service.updater.gunicorn_config"], env=my_env)
        while self.sleep(1):
            if proc.poll() is not None:
                break

        # If we have left the loop and the process is still alive, stop it.
        # poll() returns None while the child is still running.
        if proc.poll() is None:
            proc.terminate()
            proc.wait()

    @staticmethod
    def config_hash(service: Service) -> int:
        """Return a stable integer fingerprint of the service's update configuration.

        The value is persisted in redis and compared across process restarts, so it must
        not depend on per-process state. The builtin ``hash()`` of a ``str`` is salted per
        interpreter (PYTHONHASHSEED), which made every restart look like a config change.
        """
        import hashlib

        if service is None:
            return 0
        serialized = json.dumps(service.update_config.as_primitives(), sort_keys=True)
        return int.from_bytes(hashlib.sha256(serialized.encode()).digest()[:8], 'big', signed=True)

    def _handle_signature_change_event(self, data: SignatureChange):
        self.local_update_flag.set()

    def _handle_service_change_event(self, data: ServiceChange):
        if data.operation == Operation.Modified:
            self._pull_settings()

    def _sync_settings(self):
        # Download the service object from datastore
        self._service = self.datastore.get_service_with_delta(SERVICE_NAME)

        while self.sleep(SERVICE_PULL_INTERVAL):
            self._pull_settings()

    def _pull_settings(self):
        # Download the service object from datastore
        self._service = self.datastore.get_service_with_delta(SERVICE_NAME)

        # If the update configuration for the service has changed, trigger an update
        if self.config_hash(self._service) != self.get_active_config_hash():
            self.source_update_flag.set()

    def do_local_update(self) -> None:
        """Package the current signatures (or latest_updates dir) and start serving them."""
        old_update_time = self.get_local_update_time()
        os.makedirs(UPDATER_DIR, exist_ok=True)

        # mkstemp returns an open descriptor; we only need the path, so close it right away
        time_keeper_fd, time_keeper = tempfile.mkstemp(prefix="time_keeper_", dir=UPDATER_DIR)
        os.close(time_keeper_fd)

        if self._service.update_config.generates_signatures:
            output_directory = tempfile.mkdtemp(prefix="update_dir_", dir=UPDATER_DIR)

            self.log.info("Setup service account.")
            username = self.ensure_service_account()
            self.log.info("Create temporary API key.")
            with temporary_api_key(self.datastore, username) as api_key:
                self.log.info(f"Connecting to Assemblyline API: {UI_SERVER}")
                al_client = get_client(UI_SERVER, apikey=(username, api_key), verify=False)

                # Check if new signatures have been added
                self.log.info("Check for new signatures.")
                if al_client.signature.update_available(
                        since=epoch_to_iso(old_update_time) or '', sig_type=self.updater_type)['update_available']:
                    self.log.info("An update is available for download from the datastore")

                    self.log.debug(f"{self.updater_type} update available since {epoch_to_iso(old_update_time) or ''}")

                    extracted_zip = False
                    attempt = 0

                    # Sometimes a zip file isn't always returned, will affect service's use of signature source.
                    # Patience..
                    while not extracted_zip and attempt < 5:
                        temp_zip_file = os.path.join(output_directory, 'temp.zip')
                        al_client.signature.download(
                            output=temp_zip_file,
                            query=f"type:{self.updater_type} AND (status:NOISY OR status:DEPLOYED)")

                        self.log.debug(f"Downloading update to {temp_zip_file}")
                        if os.path.exists(temp_zip_file) and os.path.getsize(temp_zip_file) > 0:
                            self.log.debug(f"File type ({os.path.getsize(temp_zip_file)}B): "
                                           f"{zip_ident(temp_zip_file, 'unknown')}")
                            try:
                                with ZipFile(temp_zip_file, 'r') as zip_f:
                                    zip_f.extractall(output_directory)
                                    extracted_zip = True
                                    self.log.info("Zip extracted.")
                            except BadZipFile:
                                attempt += 1
                                self.log.warning(f"[{attempt}/5] Bad zip. Trying again after 30s...")
                                time.sleep(30)
                            except Exception as e:
                                self.log.error(f'Problem while extracting signatures to disk: {e}')
                                break

                            os.remove(temp_zip_file)
                        else:
                            # Nothing usable came back. Count it as a failed attempt; without this
                            # the loop would spin forever, hammering the API with no delay.
                            attempt += 1
                            self.log.warning(f"[{attempt}/5] Empty download. Trying again after 30s...")
                            if os.path.exists(temp_zip_file):
                                os.remove(temp_zip_file)
                            time.sleep(30)

                    if extracted_zip:
                        self.log.info("New ruleset successfully downloaded and ready to use")
                        self.serve_directory(output_directory, time_keeper)
                    else:
                        self.log.error("Signatures aren't saved to disk.")
                        shutil.rmtree(output_directory, ignore_errors=True)
                        if os.path.exists(time_keeper):
                            os.unlink(time_keeper)
                else:
                    self.log.info("No signature updates available.")
                    shutil.rmtree(output_directory, ignore_errors=True)
                    if os.path.exists(time_keeper):
                        os.unlink(time_keeper)
        else:
            output_directory = self.prepare_output_directory()
            self.serve_directory(output_directory, time_keeper)

    def do_source_update(self, service: Service) -> None:
        """Fetch every configured source and import changed files into Assemblyline."""
        self.log.info(f"Connecting to Assemblyline API: {UI_SERVER}...")
        run_time = time.time()
        username = self.ensure_service_account()
        with temporary_api_key(self.datastore, username) as api_key:
            with tempfile.TemporaryDirectory() as update_dir:
                al_client = get_client(UI_SERVER, apikey=(username, api_key), verify=False)
                old_update_time = self.get_source_update_time()

                self.log.info("Connected!")

                # Parse updater configuration
                previous_hashes: dict[str, dict[str, str]] = self.get_source_extra()
                sources: dict[str, UpdateSource] = {_s['name']: _s for _s in service.update_config.sources}
                files_sha256: dict[str, dict[str, str]] = {}

                # Go through each source and download file
                for source_name, source_obj in sources.items():
                    source = source_obj.as_primitives()
                    uri: str = source['uri']
                    default_classification = source.get('default_classification', classification.UNRESTRICTED)
                    try:
                        # Pull sources from external locations (method depends on the URL)
                        files = git_clone_repo(source, old_update_time, self.default_pattern, self.log, update_dir) \
                            if uri.endswith('.git') else url_download(source, old_update_time, self.log, update_dir)

                        # Add to collection of sources for caching purposes
                        self.log.info(f"Found new {self.updater_type} rule files to process for {source_name}!")
                        previous_source_hashes = previous_hashes.get(source_name, {})
                        validated_files = list()
                        for file, sha256 in files:
                            source_hashes = files_sha256.setdefault(source_name, {})
                            if previous_source_hashes.get(file, None) == sha256:
                                # Unchanged since the last import: keep remembering its hash,
                                # otherwise it would be treated as new (and re-imported) next run.
                                source_hashes[file] = sha256
                            elif self.is_valid(file):
                                source_hashes[file] = sha256
                                validated_files.append((file, sha256))

                        # Import into Assemblyline
                        self.import_update(validated_files, al_client, source_name, default_classification)

                    except SkipSource:
                        # This source hasn't changed, no need to re-import into Assemblyline
                        self.log.info(f'No new {self.updater_type} rule files to process for {source_name}')
                        if source_name in previous_hashes:
                            files_sha256[source_name] = previous_hashes[source_name]
                        continue

        self.set_source_update_time(run_time)
        self.set_source_extra(files_sha256)
        self.set_active_config_hash(self.config_hash(service))
        self.local_update_flag.set()

    # Define to determine if file is a valid signature file
    def is_valid(self, file_path) -> bool:
        return True

    # Define how your source update gets imported into Assemblyline
    def import_update(self, files_sha256: List[Tuple[str, str]], client: Client, source_name: str,
                      default_classification=None):
        raise NotImplementedError()

    # Define how to prepare the output directory before being served, must return the path of the directory to serve.
    def prepare_output_directory(self) -> str:
        output_directory = tempfile.mkdtemp()
        shutil.copytree(self.latest_updates_dir, output_directory, dirs_exist_ok=True)
        return output_directory

    def _run_source_updates(self):
        # Wait until basic data is loaded
        while self._service is None and self.sleep(1):
            pass
        if not self._service:
            return
        self.log.info("Service info loaded")

        try:
            self.log.info("Checking for in cluster update cache")
            self.do_local_update()
            self._service_stage_hash.set(SERVICE_NAME, ServiceStage.Running)
            self.event_sender.send(SERVICE_NAME, {'operation': Operation.Modified, 'name': SERVICE_NAME})
        except Exception:
            self.log.exception('An error occurred loading cached update files. Continuing.')
        self.local_update_start.set()

        # Go into a loop running the update whenever triggered or its time to
        while self.running:
            # Snapshot the current service object (and its update configuration)
            service = self._service
            update_interval = service.update_config.update_interval_seconds

            # Is it time to update yet?
            if time.time() - self.get_source_update_time() < update_interval and not self.source_update_flag.is_set():
                self.source_update_flag.wait(60)
                continue

            if not self.running:
                return

            # With temp directory
            self.source_update_flag.clear()
            self.log.info('Calling update function...')
            # Run update function
            # noinspection PyBroadException
            try:
                self.do_source_update(service=service)
            except Exception:
                self.log.exception('An error occurred running the update. Will retry...')
                self.source_update_flag.set()
                self.sleep(60)
                continue

    def serve_directory(self, new_directory: str, new_time: str):
        """Tar ``new_directory`` and atomically swap it in as the served update.

        Whatever was served before (directory, tarball, time keeper) is deleted afterwards.
        On failure, the new artefacts are deleted instead and the old ones keep being served.
        """
        self.log.info("Update finished with new data.")
        new_tar = ''
        try:
            # Tar update directory
            tar_fd, new_tar = tempfile.mkstemp(prefix="signatures_", dir=UPDATER_DIR, suffix='.tar.bz2')
            os.close(tar_fd)
            with tarfile.open(new_tar, 'w:bz2') as tar_handle:
                tar_handle.add(new_directory, '/')

            # swap update directory with old one
            self._update_dir, new_directory = new_directory, self._update_dir
            self._update_tar, new_tar = new_tar, self._update_tar
            self._time_keeper, new_time = new_time, self._time_keeper
            self.log.info(f"Now serving: {self._update_dir} and {self._update_tar} ({self.get_local_update_time()})")
        finally:
            if new_tar and os.path.exists(new_tar):
                self.log.info(f"Remove old tar file: {new_tar}")
                # Grace period for any download still reading the old tarball
                time.sleep(3)
                os.unlink(new_tar)
            if new_directory and os.path.exists(new_directory):
                self.log.info(f"Remove old directory: {new_directory}")
                shutil.rmtree(new_directory, ignore_errors=True)
            if new_time and os.path.exists(new_time):
                self.log.info(f"Remove old time keeper file: {new_time}")
                os.unlink(new_time)

    def _run_local_updates(self):
        # Wait until basic data is loaded
        while self._service is None and self.sleep(1):
            pass
        if not self._service:
            return
        self.local_update_start.wait()

        # Go into a loop running the update whenever triggered or its time to
        while self.running:
            # Is it time to update yet?
            if not self.local_update_flag.is_set():
                self.local_update_flag.wait(60)
                continue

            if not self.running:
                return
            self.local_update_flag.clear()

            # With temp directory
            self.log.info('Updating local files...')

            # Run update function
            # noinspection PyBroadException
            try:
                self.do_local_update()
                if self._service_stage_hash.get(SERVICE_NAME) == ServiceStage.Update:
                    self._service_stage_hash.set(SERVICE_NAME, ServiceStage.Running)
                    self.event_sender.send(SERVICE_NAME, {'operation': Operation.Modified, 'name': SERVICE_NAME})
            except Exception:
                self.log.exception('An error occurred finding new local files. Will retry...')
                self.local_update_flag.set()
                self.sleep(60)
                continue

    def ensure_service_account(self):
        """Check that the update service account exists, if it doesn't, create it."""
        import secrets

        uname = 'update_service_account'

        if self.datastore.user.get_if_exists(uname):
            return uname

        # Credentials must come from a CSPRNG, not the `random` module
        password = ''.join(secrets.choice(string.ascii_letters) for _ in range(20))
        user_data = User({
            "agrees_with_tos": "NOW",
            "classification": "RESTRICTED",
            "name": "Update Account",
            "password": get_password_hash(password),
            "uname": uname,
            "type": ["signature_importer"]
        })
        self.datastore.user.save(uname, user_data)
        self.datastore.user_settings.save(uname, UserSettings())
        return uname
def do_ui(self, args):
    """
    Perform UI related operations

    Usage:
        ui show_sessions [username]
        ui clear_sessions [username]

    actions:
        show_sessions      Shows all active sessions
        clear_sessions     Removes all active sessions

    Parameters:
        username           Username used to filter sessions [optional]

    Examples:
        # Clear sessions for user bob
        ui clear_sessions bob

        # Show all current sessions
        ui show_sessions
    """
    valid_func = ['clear_sessions', 'show_sessions']
    args = self._parse_args(args)

    # Validate everything before touching redis
    if len(args) not in [1, 2]:
        self._print_error("Wrong number of arguments for ui command.")
        return

    func = args[0]
    if func not in valid_func:
        self._print_error(f"Invalid action '{func}' for ui command.")
        return

    username = args[1] if len(args) == 2 else None
    flsk_sess = Hash("flask_sessions", host=config.core.redis.nonpersistent.host,
                     port=config.core.redis.nonpersistent.port)

    if func == 'clear_sessions':
        if not username:
            flsk_sess.delete()
            self.logger.info("All sessions were cleared.")
        else:
            for k, v in flsk_sess.items().items():
                if v.get('username', None) == username:
                    self.logger.info(f"Removing session: {v}")
                    flsk_sess.pop(k)
            self.logger.info(f"All sessions for user '{username}' removed.")
    else:
        # func == 'show_sessions' (guaranteed by the valid_func check above)
        if not username:
            for v in flsk_sess.items().values():
                self.logger.info(f"{v.get('username', None)} => {v}")
        else:
            self.logger.info(f'Showing sessions for user {username}:')
            for v in flsk_sess.items().values():
                if v.get('username', None) == username:
                    self.logger.info(f" {v}")
class Ingester:
    """Internal interface to the ingestion queues."""

    def __init__(self, datastore, logger, classification=None, redis=None, persistent_redis=None,
                 metrics_name='ingester'):
        self.datastore = datastore
        self.log = logger

        # Cache the user groups
        self.cache_lock = threading.RLock()  # TODO are middle man instances single threaded now?
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Create a config cache that will refresh config values periodically
        self.config = forge.CachedObject(forge.get_config)

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so has to be done at init rather than load
        constants = forge.get_constants(self.config)
        self.priority_value = constants.PRIORITIES
        self.priority_range = constants.PRIORITY_RANGES
        self.threshold_value = constants.PRIORITY_THRESHOLDS

        # Connect to the redis servers
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.persistent_redis = persistent_redis or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester', schema=Metrics, redis=self.redis,
                                      config=self.config, name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.persistent_redis)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(_completeq_name, self.redis)

        # Internal. Dropped entries are placed on this queue.
        # self.drop_queue = NamedQueue('m-drop', self.persistent_redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.persistent_redis)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.persistent_redis)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.persistent_redis)

        # Internal, timeout watch queue
        self.timeout_queue = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the complete
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestion of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.persistent_redis)

        # Output. submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore, redis=self.redis)

    def get_groups_from_user(self, username: str) -> List[str]:
        """Return the groups of ``username``, cached and reset at the top of each hour."""
        # Reset the group cache at the top of each hour
        if time.time() // HOUR_IN_SECONDS > self._user_groups_reset:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        # Get the groups for this user if not known
        if username not in self._user_groups:
            user_data = self.datastore.user.get(username)
            if user_data:
                self._user_groups[username] = user_data.groups
            else:
                self._user_groups[username] = []
        return self._user_groups[username]

    def ingest(self, task: IngestTask):
        """Validate ``task``, assign its priority and queue it (or finalize/drop it)."""
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Task received for processing")
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task, failure="Invalid sha256", logfunc=self.log.warning)
            return

        # Clean up metadata strings, since we may delete some, iterate on a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(f'[{task.ingest_id} :: {task.sha256}] '
                              f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(task.params.submitter)

        # Check if this file is already being processed
        pprevious, previous, score = None, False, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        # Integer division keeps the priority an int; the `or 1` keeps it from reaching 0.
        current_time = now()
        if priority and self.expired(current_time - task.submission.time.timestamp(), 0):
            priority = (priority // 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(self, task: IngestTask):
        """Look up the filescore cache for ``task``.

        Returns ``(psid, sid, score, key)``; ``sid`` is False and ``psid``/``score`` may be
        None when there is no usable cache entry (miss, expired, or stale).
        """
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                self.counter.increment('cache_miss')
                return None, False, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired")
            self.counter.increment('cache_expired')
            # Same lock as every other access to the local cache
            with self.cache_lock:
                self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, False, None, key
        elif self.stale(age, errors):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale")
            self.counter.increment('cache_stale')
            return None, False, result.score, key

        return result.psid, result.sid, result.score, key

    def stale(self, delta: float, errors: int):
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256=None):
        """Return (and memoize on the task) the filescore cache key for ``task``."""
        if not sha256:
            sha256 = task.submission.files[0].sha256
        key = task.scan_key
        if not key:
            key = task.params.create_filescore_key(sha256)
            task.scan_key = key
        return key

    def completed(self, sub):
        """Invoked when notified that a submission has completed."""
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                             f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
        self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = list(exhaust())
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        """Log ``task``'s failure (if any) and push it on its notification queue, if it has one."""
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(note_queue, self.persistent_redis)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
        """Decide whether ``task`` should be dropped; notify and count it if so."""
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng), sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256}, expiry, c12n, redis=self.redis)

    def is_whitelisted(self, task: IngestTask):
        """Return the whitelisting reason for ``task`` (falsy if not whitelisted)."""
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=_completeq_name,
        )

        self.timeout_queue.push(int(now(_max_time)), task.scan_key)
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis")

    def retry(self, task, scan_key, ex):
        """Requeue ``task`` after a failure, unless it ran out of retries or has expired."""
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}')
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission')
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})')
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.json())

    def finalize(self, psid, sid, score, task: IngestTask):
        """Record the result on ``task``, alert/notify, and resubmit for extended analysis if needed."""
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                          f"to {'update' if will_resubmit else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis")
            task.params.psid = sid
            task.submission.sid = None
            task.params.services.resubmit = []
            task.scan_key = None
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score):
        """True when the task asked for alerts and ``score`` reaches the critical threshold."""
        if not task.params.generate_alert:
            return False

        if score < self.threshold_value['critical']:
            return False

        return True
# Scratch directory where uploaded files are staged before submission
TEMP_SUBMIT_DIR = "/var/lib/assemblyline/submit/"

# Per-user concurrent API quota tracker; entries time out after 2 minutes
QUOTA_TRACKER = UserQuotaTracker(
    'quota',
    timeout=2 * 60,
    host=config.core.redis.nonpersistent.host,
    port=config.core.redis.nonpersistent.port,
)

# Per-user concurrent submission tracker; entries time out after 60 minutes
SUBMISSION_TRACKER = UserQuotaTracker(
    'submissions',
    timeout=60 * 60,
    host=config.core.redis.persistent.host,
    port=config.core.redis.persistent.port,
)

# Flask session storage, kept in non-persistent redis
KV_SESSION = Hash(
    "flask_sessions",
    host=config.core.redis.nonpersistent.host,
    port=config.core.redis.nonpersistent.port,
)


@functools.lru_cache(maxsize=128)
def get_submission_traffic_channel():
    """Return the shared 'submissions' traffic channel (one instance per process)."""
    nonpersistent = config.core.redis.nonpersistent
    return CommsQueue('submissions', host=nonpersistent.host, port=nonpersistent.port)


def get_token_store(key):
    """Return the short-lived (2 minute TTL) OAuth token set for ``key``."""
    store_name = f"oauth_token_{key}"
    nonpersistent = config.core.redis.nonpersistent
    return ExpiringSet(store_name, host=nonpersistent.host, port=nonpersistent.port, ttl=2 * 60)
class Ingester(ThreadedCoreBase):
    """Move submission requests from the ingest queue into the dispatcher.

    Incoming requests are parsed, deduplicated against a local/remote filescore
    cache, prioritised, filtered (drop / whitelist), and finally submitted to
    the dispatcher. Worker threads also handle completion notices, retries and
    timeouts of in-flight submissions.
    """

    def __init__(self, datastore=None, logger=None, classification=None, redis=None, persistent_redis=None,
                 metrics_name='ingester', config=None):
        super().__init__('assemblyline.ingester', logger, redis=redis, redis_persist=persistent_redis,
                         datastore=datastore, config=config)

        # Cache the user groups
        self.cache_lock = threading.RLock()
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        # Local filescore cache (scan key -> FileScore); guarded by cache_lock.
        self.cache = {}
        self.notification_queues = {}
        # sha256 -> whitelist reason, remembered across tasks; guarded by whitelisted_lock.
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so has to be done at init rather than load
        constants = forge.get_constants(self.config)
        self.priority_value: dict[str, int] = constants.PRIORITIES
        self.priority_range: dict[str, Tuple[int, int]] = constants.PRIORITY_RANGES
        self.threshold_value: dict[str, int] = constants.PRIORITY_THRESHOLDS

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester', schema=Metrics, redis=self.redis,
                                      config=self.config, name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.redis_persist)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.redis_persist)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.redis_persist)

        # Internal, timeout watch queue
        self.timeout_queue: PriorityQueue[str] = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the complete
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestion of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.redis_persist)

        # Output. submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore, redis=self.redis)

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}")
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(server_url=self.config.core.metrics.apm_server.server_url,
                                                service_name="ingester")
        else:
            self.apm_client = None

    def try_run(self):
        """Build the table of named worker threads and hand it to ``maintain_threads``."""
        threads_to_maintain = {
            'Retries': self.handle_retries,
            'Timeouts': self.handle_timeouts
        }
        threads_to_maintain.update({f'Complete_{n}': self.handle_complete for n in range(COMPLETE_THREADS)})
        threads_to_maintain.update({f'Ingest_{n}': self.handle_ingest for n in range(INGEST_THREADS)})
        threads_to_maintain.update({f'Submit_{n}': self.handle_submit for n in range(SUBMIT_THREADS)})
        self.maintain_threads(threads_to_maintain)

    def handle_ingest(self):
        """Worker loop: parse raw messages from the ingest queue and pass them to ``ingest``."""
        cpu_mark = time.process_time()
        time_mark = time.time()

        # Move from ingest to unique and waiting queues.
        # While there are entries in the ingest queue we consume chunk_size
        # entries at a time and move unique entries to uniqueq / queued and
        # duplicates to their own queues / waiting.
        while self.running:
            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

            message = self.ingest_queue.pop(timeout=1)

            cpu_mark = time.process_time()
            time_mark = time.time()

            if not message:
                continue

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            try:
                if 'submission' in message:
                    # A retried task
                    task = IngestTask(message)
                else:
                    # A new submission
                    sub = MessageSubmission(message)
                    task = IngestTask(dict(
                        submission=sub,
                        ingest_id=sub.sid,
                    ))
                    task.submission.sid = None  # Reset to new random uuid
                    # Write all input to the traffic queue
                    self.traffic_queue.publish(SubmissionMessage({
                        'msg': sub,
                        'msg_type': 'SubmissionIngested',
                        'sender': 'ingester',
                    }).as_primitives())

            except (ValueError, TypeError) as error:
                self.counter.increment('error')
                self.log.exception(f"Dropped ingest submission {message} because {str(error)}")

                # End of ingest message (value_error)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_input', 'value_error')
                continue

            self.ingest(task)

            # End of ingest message (success)
            if self.apm_client:
                self.apm_client.end_transaction('ingest_input', 'success')

    def handle_submit(self):
        """Worker loop: pull unique tasks and submit them, handling duplicates and failures."""
        time_mark, cpu_mark = time.time(), time.process_time()

        while self.running:
            # noinspection PyBroadException
            try:
                self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
                self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

                # Check if there is room for more submissions
                length = self.scanning.length()
                if length >= self.config.core.ingester.max_inflight:
                    self.sleep(0.1)
                    time_mark, cpu_mark = time.time(), time.process_time()
                    continue

                raw = self.unique_queue.blocking_pop(timeout=3)
                time_mark, cpu_mark = time.time(), time.process_time()
                if not raw:
                    continue

                # Start of ingest message
                if self.apm_client:
                    self.apm_client.begin_transaction('ingest_msg')

                task = IngestTask(raw)

                # Check if we need to drop a file for capacity reasons, but only if the
                # number of files in flight is already over 80%
                if length >= self.config.core.ingester.max_inflight * 0.8 and self.drop(task):
                    # End of ingest message (dropped)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'dropped')
                    continue

                if self.is_whitelisted(task):
                    # End of ingest message (whitelisted)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'whitelisted')
                    continue

                # Check if this file has been previously processed.
                pprevious, previous, score, scan_key = None, None, None, None
                if not task.submission.params.ignore_cache:
                    pprevious, previous, score, scan_key = self.check(task)
                else:
                    scan_key = self.stamp_filescore_key(task)

                # If it HAS been previously processed, we are dealing with a resubmission
                # finalize will decide what to do, and put the task back in the queue
                # rewritten properly if we are going to run it again
                if previous:
                    if not task.submission.params.services.resubmit and not pprevious:
                        self.log.warning(f"No psid for what looks like a resubmission of "
                                         f"{task.submission.files[0].sha256}: {scan_key}")
                    self.finalize(pprevious, previous, score, task)
                    # End of ingest message (finalized)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'finalized')

                    continue

                # We have decided this file is worth processing

                # Add the task to the scanning table, this is atomic across all submit
                # workers, so if it fails, someone beat us to the punch, record the file
                # as a duplicate then.
                if not self.scanning.add(scan_key, task.as_primitives()):
                    self.log.debug('Duplicate %s', task.submission.files[0].sha256)
                    self.counter.increment('duplicates')
                    self.duplicate_queue.push(_dup_prefix + scan_key, task.as_primitives())
                    # End of ingest message (duplicate)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'duplicate')

                    continue

                # We have managed to add the task to the scan table, so now we go
                # ahead with the submission process
                try:
                    self.submit(task)
                    # End of ingest message (submitted)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'submitted')

                    continue
                except Exception as _ex:
                    # For some reason (contained in `ex`) we have failed the submission
                    # The rest of this function is error handling/recovery
                    ex = _ex

                self.counter.increment('error')

                should_retry = True
                if isinstance(ex, CorruptedFileStoreException):
                    self.log.error("Submission for file '%s' failed due to corrupted "
                                   "filestore: %s" % (task.sha256, str(ex)))
                    should_retry = False
                elif isinstance(ex, DataStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed due to "
                                   "data store error:\n%s" % (task.sha256, trace))
                elif not isinstance(ex, FileStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed: %s" % (task.sha256, trace))

                # Check the raw scanning entry before building a task from it:
                # constructing an IngestTask from a missing (None) entry cannot
                # produce a usable "empty" task, so the check must happen first.
                scanning_entry = self.scanning.pop(scan_key)
                if not scanning_entry:
                    self.log.error('No scanning entry for %s', task.sha256)
                    # End of ingest message (no_scan_entry)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'no_scan_entry')
                    continue
                task = IngestTask(scanning_entry)

                if not should_retry:
                    # End of ingest message (cannot_retry)
                    if self.apm_client:
                        self.apm_client.end_transaction('ingest_submit', 'cannot_retry')
                    continue

                self.retry(task, scan_key, ex)
                # End of ingest message (retry)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit', 'retried')

            except Exception:
                self.log.exception("Unexpected error")
                # End of ingest message (exception)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit', 'exception')

    def handle_complete(self):
        """Worker loop: pop completion records from the dispatcher and process them."""
        while self.running:
            result = self.complete_queue.pop(timeout=3)
            if not result:
                continue

            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            sub = DatabaseSubmission(result)
            self.completed(sub)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(sid=sub.sid)
                self.apm_client.end_transaction('ingest_complete', 'success')

            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

    def handle_retries(self):
        """Worker loop: move due retries (up to 100 per pass) back onto the ingest queue."""
        tasks = []
        # Poll again immediately while the last pass found work, otherwise wait 3 seconds.
        while self.sleep(0 if tasks else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_retries')

            tasks = self.retry_queue.dequeue_range(upper_limit=isotime.now(), num=100)

            for task in tasks:
                self.ingest_queue.push(task)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(retries=len(tasks))
                self.apm_client.end_transaction('ingest_retries', 'success')

            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

    def handle_timeouts(self):
        """Worker loop: drop scanning entries and duplicates whose timeout deadline passed."""
        timeouts = []
        # Poll again immediately while the last pass found work, otherwise wait 3 seconds.
        while self.sleep(0 if timeouts else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_timeouts')

            timeouts = self.timeout_queue.dequeue_range(upper_limit=isotime.now(), num=100)

            for scan_key in timeouts:
                # noinspection PyBroadException
                try:
                    actual_timeout = False

                    # Remove the entry from the hash of submissions in progress.
                    entry = self.scanning.pop(scan_key)
                    if entry:
                        actual_timeout = True
                        self.log.error("Submission timed out for %s: %s", scan_key, str(entry))

                    dup = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)
                    if dup:
                        actual_timeout = True

                    while dup:
                        self.log.error("Submission timed out for %s: %s", scan_key, str(dup))
                        dup = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)

                    if actual_timeout:
                        self.counter.increment('timed_out')
                except Exception:
                    self.log.exception("Problem timing out %s:", scan_key)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(timeouts=len(timeouts))
                self.apm_client.end_transaction('ingest_timeouts', 'success')

            self.counter.increment_execution_time('cpu_seconds', time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds', time.time() - time_mark)

    def get_groups_from_user(self, username: str) -> List[str]:
        """Return the groups of ``username``, cached and reset at the top of each hour.

        This is called concurrently by several ingest threads. The cache dict is
        bound to a local name so that, if another thread swaps in a fresh dict
        mid-call, this call keeps reading the dict it populated instead of
        raising ``KeyError`` on the new, empty one.

        :param username: Name of the submitting user.
        :return: The user's groups, or an empty list for an unknown user.
        """
        current_hour = time.time() // HOUR_IN_SECONDS
        groups_cache = self._user_groups

        # Reset the group cache at the top of each hour
        if current_hour > self._user_groups_reset:
            groups_cache = self._user_groups = {}
            self._user_groups_reset = current_hour

        # Get the groups for this user if not known
        if username not in groups_cache:
            user_data = self.datastore.user.get(username)
            groups_cache[username] = user_data.groups if user_data else []

        return groups_cache[username]

    def ingest(self, task: IngestTask):
        """Validate, prioritise and route one ingest task.

        Invalid or oversized tasks are rejected. Cache hits are finalized
        immediately; everything else is either dropped, whitelisted, or pushed
        onto the unique queue at its computed priority.
        """
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Task received for processing")
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task, failure="Invalid sha256", logfunc=self.log.warning)
            return

        # Clean up metadata strings, since we may delete some, iterate on a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(f'[{task.ingest_id} :: {task.sha256}] '
                              f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(task.params.submitter)

        # Check if this file is already being processed
        self.stamp_filescore_key(task)
        pprevious, previous, score = None, None, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task, count_miss=False)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        # Integer division keeps the priority an int; `or 1` stops a small
        # non-zero priority from collapsing to 0.
        current_time = now()
        if priority and self.expired(current_time - task.submission.time.timestamp(), 0):
            priority = (priority // 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)

            # On cache hits of any kind we want to send out a completed message
            self.traffic_queue.publish(SubmissionMessage({
                'msg': task.submission,
                'msg_type': 'SubmissionCompleted',
                'sender': 'ingester',
            }).as_primitives())
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(self, task: IngestTask, count_miss=True) -> Tuple[Optional[str], Optional[str], Optional[float], str]:
        """Look up a previous result for this task in the local then remote filescore cache.

        :param task: The task to look up (its scan key is stamped if missing).
        :param count_miss: Whether a complete miss increments the ``cache_miss`` counter.
        :return: ``(psid, sid, score, scan_key)``. psid/sid/score are None on a miss
                 or an expired entry; for a stale entry only the score is returned.
        """
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get_if_exists(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                if count_miss:
                    self.counter.increment('cache_miss')
                return None, None, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired")
            self.counter.increment('cache_expired')
            # Every other access to self.cache holds cache_lock; do the same here.
            with self.cache_lock:
                self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, None, None, key
        elif self.stale(age, errors):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale")
            self.counter.increment('cache_stale')
            return None, None, result.score, key

        return result.psid, result.sid, result.score, key

    def stop(self):
        """Stop worker threads, disable APM instrumentation if enabled, and stop the submit client."""
        super().stop()
        if self.apm_client:
            elasticapm.uninstrument()
        self.submit_client.stop()

    def stale(self, delta: float, errors: int):
        """Return True if a cached result of age ``delta`` seconds is stale (stricter limit when errors)."""
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256: Optional[str] = None) -> str:
        """Return the task's scan key, computing and storing it on the submission if unset.

        :param task: The task whose scan key is needed.
        :param sha256: File hash to key on; defaults to the submission's first file.
        """
        if not sha256:
            sha256 = task.submission.files[0].sha256
        key = task.submission.scan_key
        if not key:
            key = task.params.create_filescore_key(sha256)
            task.submission.scan_key = key
        return key

    def completed(self, sub: DatabaseSubmission):
        """Invoked when notified that a submission has completed.

        Caches the resulting file score, finalizes the original task and every
        duplicate task that was waiting on the same scan key.

        :param sub: The completed submission as stored in the datastore.
        :return: The scan key of the submission.
        """
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.scan_key
        if not scan_key:
            self.log.warning(f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                             f"Submission missing scan key")
            scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                             f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
        self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            # Drain the duplicate queue for this scan key without blocking.
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = [dup for dup in exhaust()]
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        """Log a failure (if any) and push the task onto its notification queue, if one is set.

        :param task: The task to notify about.
        :param failure: Optional failure message to record on the task first.
        :param logfunc: Logging function for the failure (defaults to ``self.log.info``).
        """
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        # Respect the requested notification threshold when a score is known.
        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(note_queue, self.redis_persist)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        """Return True if something aged ``delta`` seconds has expired (stricter limit when errors)."""
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
        """Decide whether to drop a task for priority, sampling or size reasons.

        Tasks flagged ``never_drop`` are never dropped. A dropped task is marked
        'Skipped', notified and counted.

        :return: True if the task was dropped.
        """
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng), sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        """Notify about a dropped task and save/freshen its file record with a one-day expiry."""
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256}, expiry, c12n, redis=self.redis)

    def is_whitelisted(self, task: IngestTask):
        """Return the whitelist reason for a task (falsy if not whitelisted), dropping it if so."""
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            # Fall back on hashes that were whitelisted by an earlier verdict.
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
        """Hand a task to the dispatcher and register its timeout deadline."""
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=COMPLETE_QUEUE_NAME,
        )

        self.timeout_queue.push(int(now(_max_time)), task.submission.scan_key)
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis")

    def retry(self, task: IngestTask, scan_key: str, ex):
        """Requeue a failed submission unless it is out of retries or has expired.

        :param task: The task that failed to submit.
        :param scan_key: Scan key of the task (used for its duplicate queue).
        :param ex: The exception that caused the failure (may be None).
        """
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}')
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission')
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})')
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.as_primitives())

    def finalize(self, psid: str, sid: str, score: float, task: IngestTask):
        """Record the outcome of a task: alert, notify, and possibly resubmit for extended analysis.

        :param psid: Parent submission id when this is a resubmission, else falsy.
        :param sid: Submission id that produced the result.
        :param score: Score of the completed submission.
        :param task: The task being finalized.
        """
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                          f"to {'update' if task.params.psid else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis")
            # The resubmission points back at this submission and gets a fresh sid/scan key.
            task.params.psid = sid
            task.submission.sid = None
            task.submission.scan_key = None
            task.params.services.resubmit = []
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score: float) -> bool:
        """Return True when the task requests alerts and its score reaches the alerter threshold."""
        if not task.params.generate_alert:
            return False

        if score < self.config.core.alerter.threshold:
            return False

        return True
def dispatch_submission(self, task: SubmissionTask):
    """Dispatch the unfinished files of a submission, or finalize it when none remain.

    Every file associated with the submission is examined: its root files plus
    every file recorded in the dispatch table's file tree. Files that still have
    outstanding services are pushed onto the file queue, after depth and
    extraction-limit filtering. If no file is pending, the submission is
    finalized with the highest per-file score (0 if there are no results).

    This version of dispatch submission doesn't verify each result, but assumes
    that the dispatch table has been kept up to date by other components.

    Preconditions:
        - File exists in the filestore and file collection in the datastore
        - Submission is stored in the datastore

    :param task: The submission task to (re)examine.
    :return: The result of ``self.cancel_submission`` when some file records are
             missing from the datastore; otherwise None.
    """
    submission = task.submission
    sid = submission.sid

    # Register the submission as active the first time it is seen.
    if not self.active_submissions.exists(sid):
        self.log.info(f"[{sid}] New submission received")
        self.active_submissions.add(sid, task.as_primitives())
    else:
        self.log.info(f"[{sid}] Received a pre-existing submission, check if it is complete")

    # Refresh the watch, this ensures that this function will be called again
    # if something goes wrong with one of the files, and it never gets invoked by dispatch_file.
    self.timeout_watcher.touch(key=sid, timeout=int(self.config.core.dispatcher.timeout),
                               queue=SUBMISSION_QUEUE, message={'sid': sid})

    # Refresh the quota hold
    if submission.params.quota_item and submission.params.submitter:
        self.log.info(f"[{sid}] Submission will count towards {submission.params.submitter.upper()} quota")
        Hash('submissions-' + submission.params.submitter, self.redis_persist).add(sid, isotime.now_as_iso())

    # Open up the file/service table for this submission
    dispatch_table = DispatchHash(submission.sid, self.redis, fetch_results=True)
    file_parents = dispatch_table.file_tree()  # Load the file tree data as well

    # All the submission files, and all the file_tree files, to be sure we don't miss any incomplete children
    unchecked_hashes = [submission_file.sha256 for submission_file in submission.files]
    unchecked_hashes = list(set(unchecked_hashes) | set(file_parents.keys()))

    # Using the file tree we can recalculate the depth of any file
    depth_limit = self.config.submission.max_extraction_depth
    file_depth = depths_from_tree(file_parents)

    # Try to find all files, and extracted files, and create task objects for them
    # (we will need the file data anyway for checking the schedule later)
    max_files = len(submission.files) + submission.params.max_extracted
    unchecked_files = []  # Files that haven't been checked yet
    try:
        for sha, file_data in self.files.multiget(unchecked_hashes).items():
            unchecked_files.append(FileTask(dict(
                sid=sid,
                min_classification=task.submission.classification,
                file_info=dict(
                    magic=file_data.magic,
                    md5=file_data.md5,
                    mime=file_data.mime,
                    sha1=file_data.sha1,
                    sha256=file_data.sha256,
                    size=file_data.size,
                    type=file_data.type,
                ),
                # Files not in the tree (e.g. root files) default to depth 0.
                depth=file_depth.get(sha, 0),
                max_files=max_files
            )))
    except MultiKeyError as missing:
        # Some file records are missing from the datastore: save a
        # FAIL_NONRECOVERABLE error for each one and cancel the submission.
        errors = []
        for file_sha in missing.keys:
            error = Error(dict(
                archive_ts=submission.archive_ts,
                expiry_ts=submission.expiry_ts,
                response=dict(
                    message="Submission couldn't be completed due to missing file.",
                    service_name="dispatcher",
                    service_tool_version='4',
                    service_version='4',
                    status="FAIL_NONRECOVERABLE",
                ),
                sha256=file_sha,
                type='UNKNOWN'
            ))
            error_key = error.build_key(service_tool_version=sid)
            self.datastore.error.save(error_key, error)
            errors.append(error_key)
        return self.cancel_submission(task, errors, file_parents)

    # Files that have already been encountered, but may or may not have been processed yet
    # encountered_files = {file.sha256 for file in submission.files}
    pending_files = {}  # Files that have not yet been processed

    # Track information about the results as we hit them (sha256 -> sum of service scores)
    file_scores: Dict[str, int] = {}

    # Load the current state of the dispatch table in one go rather than one at a time in the loop
    prior_dispatches = dispatch_table.all_dispatches()

    # Walk each file's schedule stage by stage, marking the file pending if any
    # scheduled service is still running or has no result yet.
    for file_task in unchecked_files:
        sha = file_task.file_info.sha256
        schedule = self.build_schedule(dispatch_table, submission, sha, file_task.file_info.type)

        while schedule:
            stage = schedule.pop(0)
            for service_name in stage:
                # Only active services should be in this dict, so if a service that was placed in the
                # schedule is now missing it has been disabled or taken offline.
                service = self.scheduler.services.get(service_name)
                if not service:
                    continue

                # If the service is still marked as 'in progress'
                # (a file never dispatched to this service gets 0 here, so its runtime is huge)
                runtime = time.time() - prior_dispatches.get(sha, {}).get(service_name, 0)
                if runtime < service.timeout:
                    pending_files[sha] = file_task
                    continue

                # It hasn't started, has timed out, or is finished, see if we have a result
                result_row = dispatch_table.finished(sha, service_name)

                # No result found, mark the file as incomplete
                if not result_row:
                    pending_files[sha] = file_task
                    continue

                # A result asking to drop the file empties the remaining schedule,
                # so later stages are not examined (unless filtering is ignored).
                if not submission.params.ignore_filtering and result_row.drop:
                    schedule.clear()

                # The process table is marked that a service has been abandoned due to errors
                if result_row.is_error:
                    continue

                # Collect information about the result
                file_scores[sha] = file_scores.get(sha, 0) + result_row.score

    # Using the file tree find the most shallow parent of the given file
    def lowest_parent(_sha):
        # A root file won't have any parents in the dict
        if _sha not in file_parents or None in file_parents[_sha]:
            return None
        # Parents with unknown depth are ranked as if at depth_limit.
        return min((file_depth.get(parent, depth_limit), parent) for parent in file_parents[_sha])[1]

    # Filter out things over the depth limit
    pending_files = {sha: ft for sha, ft in pending_files.items() if ft.depth < depth_limit}

    # Filter out files based on the extraction limits
    # NOTE(review): presumably add_file returns False once max_files is reached — confirm in DispatchHash.
    pending_files = {sha: ft for sha, ft in pending_files.items()
                     if dispatch_table.add_file(sha, max_files, lowest_parent(sha))}

    # If there are pending files, then at least one service, on at least one
    # file isn't done yet, and hasn't been filtered by any of the previous few steps
    # poke those files
    if pending_files:
        self.log.debug(f"[{sid}] Dispatching {len(pending_files)} files: {list(pending_files.keys())}")
        for file_task in pending_files.values():
            self.file_queue.push(file_task.as_primitives())
    else:
        self.log.debug(f"[{sid}] Finalizing submission.")
        max_score = max(file_scores.values()) if file_scores else 0  # Submissions with no results have no score
        self.finalize_submission(task, max_score, file_scores.keys())
class DispatchClient:
    """Client-side interface to the dispatcher.

    Pushes submissions onto the dispatcher's submission queue, hands service tasks out to
    workers, and relays service results and errors back to the dispatcher instance that
    issued each task.
    """

    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        self.config = forge.get_config()
        # Non-persistent redis: submission queue, service work queues, running tasks
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        # Persistent redis: submission assignments and the dispatcher instance registry
        self.redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )
        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger("assemblyline.dispatching.client")
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        # submission id -> id of the dispatcher instance handling that submission
        self.submission_assignments = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        # Tasks currently handed out to workers, keyed by ServiceTask.key()
        self.running_tasks = Hash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))
        # Cached snapshot of registered dispatcher ids and the time it was taken
        self.dispatcher_data = []
        self.dispatcher_data_age = 0.0
        # Dispatcher ids found to be unregistered; rejected without further lookups
        self.dead_dispatchers = []

    @weak_lru(maxsize=128)
    def _get_queue_from_cache(self, name):
        """Return a (cached) NamedQueue on the non-persistent redis with QUEUE_EXPIRY ttl."""
        return NamedQueue(name, host=self.redis, ttl=QUEUE_EXPIRY)

    def _get_services(self):
        """Load every service from the datastore, keyed by service name."""
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def is_dispatcher(self, dispatcher_id) -> bool:
        """Return True if ``dispatcher_id`` is a currently registered dispatcher instance.

        The registry snapshot is refreshed when it is older than 120 seconds or when an
        unknown id is queried. An id still unknown after a refresh is remembered as dead
        and rejected on every later call.
        """
        if dispatcher_id in self.dead_dispatchers:
            return False
        if time.time() - self.dispatcher_data_age > 120 or dispatcher_id not in self.dispatcher_data:
            self.dispatcher_data = Dispatcher.all_instances(self.redis_persist)
            self.dispatcher_data_age = time.time()
        if dispatcher_id in self.dispatcher_data:
            return True
        else:
            self.dead_dispatchers.append(dispatcher_id)
            return False

    def dispatch_bundle(self, submission: Submission, results: Dict[str, Result],
                        file_infos: Dict[str, File], file_tree, errors: Dict[str, Error],
                        completed_queue: str = None):
        """Insert a bundle into the dispatching system and continue scanning of its files

        Prerequisites:
            - Submission, results, file_infos and errors should already be saved in the datastore
            - Files should already be in the filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            results=results,
            file_infos=file_infos,
            file_tree=file_tree,
            errors=errors,
            completed_queue=completed_queue,
        ))

    def dispatch_submission(self, submission: Submission, completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            completed_queue=completed_queue,
        ))

    def outstanding_services(self, sid) -> Optional[Dict[str, int]]:
        """
        List outstanding services for a given submission and the number of file each
        of them still have to process.

        Sends a LIST_OUTSTANDING command to the dispatcher assigned to the submission and
        waits up to 30 seconds for its reply on a one-off response queue.

        :param sid: Submission ID
        :return: Dictionary of services and number of files
                 remaining per services e.g. {"SERVICE_NAME": 1, ... }.
                 An empty dict if no dispatcher is assigned to the submission. May be
                 None if the dispatcher does not reply before the pop times out
                 (assumes NamedQueue.pop returns None on timeout — verify).
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
            queue = NamedQueue(queue_name, host=self.redis, ttl=30)
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': LIST_OUTSTANDING,
                'payload_data': ListOutstanding({
                    'response_queue': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue.pop(timeout=30)
        return {}

    @elasticapm.capture_span(span_type='dispatch_client')
    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True, low_priority=False) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param worker_id: Identifier of the requesting worker, recorded in the task metadata.
        :param service_name: Which service needs work.
        :param service_version: Version of the requesting service.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, check the
                         queue once and return immediately (timeout is then ignored).
        :param low_priority: Passed through to the queue pop.
        :return: The task found, or None if no usable task was obtained.
        """
        if not blocking:
            # A non-blocking request always checks the queue exactly once; previously a
            # timeout below one second meant the queue was never looked at.
            return self._request_work(worker_id, service_name, service_version,
                                      blocking=False, timeout=timeout, low_priority=low_priority)

        start = time.time()
        remaining = timeout
        # Retry until the time budget runs out: _request_work may dequeue a task it has
        # to discard (e.g. from a dead dispatcher) before the timeout is reached.
        while int(remaining) > 0:
            work = self._request_work(worker_id, service_name, service_version,
                                      blocking=True, timeout=remaining, low_priority=low_priority)
            if work:
                return work
            remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version,
                      timeout, blocking, low_priority=False) -> Optional[ServiceTask]:
        """Make one attempt to take a task off the service's work queue.

        Returns None when the (blocking) timeout is exhausted, the queue is empty, the task
        came from a dispatcher that is no longer registered, or the task could not be added
        to ``running_tasks``.
        """
        # Only a blocking pop needs time left on the clock
        if blocking and int(timeout) <= 0:
            self.log.info(f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout), low_priority=low_priority)
        else:
            if low_priority:
                result = work_queue.unpush(1)
            else:
                result = work_queue.pop(1)
            # Non-blocking pops return a list of messages
            if result:
                result = result[0]

        if not result:
            self.log.info(f"{service_name}:{worker_id} no task returned: [empty message]")
            return None
        task = ServiceTask(result)
        task.metadata['worker__'] = worker_id
        dispatcher = task.metadata['dispatcher__']

        if not self.is_dispatcher(dispatcher):
            self.log.info(f"{service_name}:{worker_id} no task returned: [task from dead dispatcher]")
            return None

        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found")
            # Tell the issuing dispatcher that this worker started the task
            start_queue = self._get_queue_from_cache(DISPATCH_START_EVENTS + dispatcher)
            start_queue.push((task.sid, task.fileinfo.sha256, service_name, worker_id))
            return task
        return None

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch.

        Removes the task from ``running_tasks`` (logging a warning and returning if it is not
        there), saves the result, notifies watch queues, and pushes a result summary onto the
        result queue of the dispatcher that issued the task.
        """
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(sid=sid, service_name=result.response.service_name, sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing successful results.")
            return
        task = ServiceTask(task)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            while True:
                old, version = self.ds.result.get_if_exists(
                    result_key, archive_access=self.config.datastore.ilm.update_archive, version=True)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        # A missing expiry on either copy is treated as the most distant one
                        result.expiry_ts = None
                try:
                    self.ds.result.save(result_key, result, version=version)
                    break
                except VersionConflictException as vce:
                    self.log.info(f"Retrying to save results due to version conflict: {str(vce)}")

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w, host=self.redis).push(msg)

        # Collect the tags from every result section
        tags = []
        for section in result.result.sections:
            tags.extend(tag_dict_to_list(flatten(section.tags.as_primitives())))

        # Pull out file names if we have them
        file_names = {}
        for extracted_data in result.response.extracted:
            if extracted_data.name:
                file_names[extracted_data.sha256] = extracted_data.name

        # Report back to the dispatcher instance that issued this task
        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        ex_ts = result.expiry_ts.strftime(DATEFORMAT) if result.expiry_ts else result.archive_ts.strftime(DATEFORMAT)
        result_queue.push({
            'sid': task.sid,
            'sha256': result.sha256,
            'service_name': task.service_name,
            'service_version': result.response.service_version,
            'service_tool_version': result.response.service_tool_version,
            'archive_ts': result.archive_ts.strftime(DATEFORMAT),
            'expiry_ts': ex_ts,
            'result_summary': {
                'key': result_key,
                'drop': result.drop_file,
                'score': result.result.score,
                'children': [r.sha256 for r in result.response.extracted],
            },
            'tags': tags,
            'extracted_names': file_names,
            'temporary_data': temporary_data
        })

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_failed(self, sid: str, error_key: str, error: Error):
        """Notify the dispatcher that a service failed on a task.

        Non-recoverable errors are saved to the datastore and announced to watch queues;
        every error is forwarded to the issuing dispatcher's result queue.
        """
        task_key = ServiceTask.make_key(sid=sid, service_name=error.response.service_name, sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error.")
        if error.response.status == "FAIL_NONRECOVERABLE":
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w, host=self.redis).push(msg)

        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        result_queue.push({
            'sid': task.sid,
            'service_task': task.as_primitives(),
            'error': error.as_primitives(),
            'error_key': error_key
        })

    def setup_watch_queue(self, sid: str) -> Optional[str]:
        """
        This function takes a submission ID as a parameter and creates a unique queue where all service
        result keys for that given submission will be returned to as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys
        through the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created, or None if no dispatcher
                 is assigned to the submission.
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="WQ")
            # Same ttl as the command queue used by outstanding_services, so commands
            # addressed to a vanished dispatcher do not linger forever
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': CREATE_WATCH,
                'payload_data': CreateWatch({
                    'queue_name': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue_name

    def _get_watcher_list(self, sid):
        """Return the set of watch queue names registered for a submission."""
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)
def __init__(self, datastore, logger, classification=None, redis=None, persistent_redis=None,
             metrics_name='ingester'):
    """Wire up the ingester's caches, configuration, redis connections and queues.

    :param datastore: Datastore handle, also passed to the SubmissionClient.
    :param logger: Logger used by this instance.
    :param classification: Classification engine; ``forge.get_classification()`` if not given.
    :param redis: Non-persistent redis client; built from config if not given.
    :param persistent_redis: Persistent redis client; built from config if not given.
    :param metrics_name: Name reported by the metrics factory.
    """
    self.datastore = datastore
    self.log = logger

    # In-memory caches and the locks guarding them
    self.cache_lock = threading.RLock()  # TODO are middle man instances single threaded now?
    self._user_groups = {}
    self._user_groups_reset = time.time() // HOUR_IN_SECONDS
    self.cache = {}
    self.notification_queues = {}
    self.whitelisted = {}
    self.whitelisted_lock = threading.RLock()

    # Configuration wrapper that reloads the config periodically
    self.config = forge.CachedObject(forge.get_config)

    # Pluggable functions; their module paths are only read at start-up, so
    # changing them requires a restart
    ingester_cfg = self.config.core.ingester
    self.is_low_priority = load_module_by_path(ingester_cfg.is_low_priority)
    self.get_whitelist_verdict = load_module_by_path(ingester_cfg.get_whitelist_verdict)
    self.whitelist = load_module_by_path(ingester_cfg.whitelist)

    # The constants module path comes from config, so it is resolved here rather than at import
    priority_constants = forge.get_constants(self.config)
    self.priority_value = priority_constants.PRIORITIES
    self.priority_range = priority_constants.PRIORITY_RANGES
    self.threshold_value = priority_constants.PRIORITY_THRESHOLDS

    # Redis connections: reuse the caller's clients, otherwise build them from config
    if not redis:
        nonpersistent_cfg = self.config.core.redis.nonpersistent
        redis = get_client(host=nonpersistent_cfg.host, port=nonpersistent_cfg.port, private=False)
    self.redis = redis

    if not persistent_redis:
        persistent_cfg = self.config.core.redis.persistent
        persistent_redis = get_client(host=persistent_cfg.host, port=persistent_cfg.port, private=False)
    self.persistent_redis = persistent_redis

    # Classification engine
    self.ce = classification if classification else forge.get_classification()

    # Metrics gathering factory
    self.counter = MetricsFactory(metrics_type='ingester', schema=Metrics, redis=self.redis,
                                  config=self.config, name=metrics_name)

    # State: submissions in progress live in persistent redis so they can be
    # recovered after a crash
    self.scanning = Hash('m-scanning-table', self.persistent_redis)

    # Input: the dispatcher records every completed submission here
    self.complete_queue = NamedQueue(_completeq_name, self.redis)

    # Input: external processes place submission requests here
    self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.persistent_redis)

    # Output: a copy of the incoming traffic, for other systems to consume
    self.traffic_queue = CommsQueue('submissions', self.redis)

    # Internal: unique requests are queued and processed from here
    self.unique_queue = PriorityQueue('m-unique', self.persistent_redis)

    # Internal: delay queue for retries
    self.retry_queue = PriorityQueue('m-retry', self.persistent_redis)

    # Internal: timeout watch queue
    self.timeout_queue = PriorityQueue('m-timeout', self.redis)

    # Internal: duplicates. A file whose cache key matches one already in flight
    # (same file, same submission parameters) is not ingested again; a reference is
    # written to a duplicate queue instead. When the original finishes, every entry
    # in its duplicate queue is finalized too, so concurrent ingestions of the same
    # file are merged into a single submission.
    self.duplicate_queue = MultiQueue(self.persistent_redis)

    # Output: submissions that should have alerts generated
    self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)

    # Helper used to hand tasks over to dispatching
    self.submit_client = SubmissionClient(datastore=self.datastore, redis=self.redis)
class ServiceUpdater(CoreBase):
    """Keep services' containers and update files current.

    A single-threaded ``sched`` loop periodically syncs the service list from the
    datastore, records the latest container tags, applies queued container updates,
    and starts a thread per service to run its file update when one is due.
    """

    def __init__(self, redis_persist=None, redis=None, logger=None, datastore=None):
        super().__init__('assemblyline.service.updater', logger=logger, datastore=datastore,
                         redis_persist=redis_persist, redis=redis)

        if not FILE_UPDATE_DIRECTORY:
            raise RuntimeError("The updater process must be run within the orchestration environment, "
                               "the update volume must be mounted, and the path to the volume must be "
                               "set in the environment variable FILE_UPDATE_DIRECTORY. Setting "
                               "FILE_UPDATE_DIRECTORY directly may be done for testing.")

        # The directory where we want working temporary directories to be created.
        # Building our temporary directories in the persistent update volume may
        # have some performance down sides, but may help us run into fewer docker FS overlay
        # cleanup issues. Try to flush it out every time we start. This service should
        # be a singleton anyway.
        self.temporary_directory = os.path.join(FILE_UPDATE_DIRECTORY, '.tmp')
        shutil.rmtree(self.temporary_directory, ignore_errors=True)
        # rmtree ignores errors, so the directory may survive it; don't crash on that
        os.makedirs(self.temporary_directory, exist_ok=True)

        self.container_update = Hash('container-update', self.redis_persist)
        self.services = Hash('service-updates', self.redis_persist)
        self.latest_service_tags = Hash('service-tags', self.redis_persist)
        self.running_updates: Dict[str, Thread] = {}

        # Prepare a single threaded scheduler
        self.scheduler = sched.scheduler()

        # Choose the container orchestration backend
        if 'KUBERNETES_SERVICE_HOST' in os.environ and NAMESPACE:
            self.controller = KubernetesUpdateInterface(prefix='alsvc_', namespace=NAMESPACE,
                                                        priority_class='al-core-priority')
        else:
            self.controller = DockerUpdateInterface()

    def sync_services(self):
        """Download the service list and make sure our settings are up to date"""
        import hashlib  # stable digest for config-change detection (see below)

        self.scheduler.enter(SERVICE_SYNC_INTERVAL, 0, self.sync_services)
        existing_services = (set(self.services.keys()) |
                             set(self.container_update.keys()) |
                             set(self.latest_service_tags.keys()))
        discovered_services = []

        # Get all the service data
        for service in self.datastore.list_all_services(full=True):
            discovered_services.append(service.name)

            # Ensure that any disabled services are not being updated
            if not service.enabled and self.services.exists(service.name):
                self.log.info(f"Service updates disabled for {service.name}")
                self.services.pop(service.name)

            if not service.enabled:
                continue

            # Ensure that any enabled services with an update config are being updated
            stage = self.get_service_stage(service.name)
            record = self.services.get(service.name)

            if stage in UPDATE_STAGES and service.update_config:
                # Stringify and hash the current update configuration. The builtin hash()
                # of a str is salted per process (PYTHONHASHSEED), so it cannot be compared
                # with a value a previous process stored in redis; use a stable digest.
                config_hash = hashlib.sha256(
                    json.dumps(service.update_config.as_primitives(), sort_keys=True).encode()
                ).hexdigest()

                # If we can update, but there is no record, create one
                if not record:
                    self.log.info(f"Service updates enabled for {service.name}")
                    self.services.add(
                        service.name,
                        dict(
                            next_update=now_as_iso(),
                            previous_update=now_as_iso(-10**10),
                            config_hash=config_hash,
                            sha256=None,
                        )
                    )
                else:
                    # If there is a record, check that its configuration hash is still good
                    # If an update is in progress, it may overwrite this, but we will just come back
                    # and reapply this again in the iteration after that
                    if record.get('config_hash', None) != config_hash:
                        record['next_update'] = now_as_iso()
                        record['config_hash'] = config_hash
                        self.services.set(service.name, record)

            if stage == ServiceStage.Update:
                if (record and record.get('sha256', None) is not None) or not service.update_config:
                    self._service_stage_hash.set(service.name, ServiceStage.Running)

        # Remove services we have locally or in redis that have been deleted from the database
        for stray_service in existing_services - set(discovered_services):
            self.log.info(f"Service updates disabled for {stray_service}")
            self.services.pop(stray_service)
            self._service_stage_hash.pop(stray_service)
            self.container_update.pop(stray_service)
            self.latest_service_tags.pop(stray_service)

    def container_updates(self):
        """Apply the container updates queued in the ``container-update`` hash.

        For each queued service, its new image is launched in registration-only mode; if the
        resulting ``{name}_{tag}`` service entry exists, the service's version is set to the
        new tag. Each entry is removed from the queue whether or not the update succeeded.
        """
        self.scheduler.enter(UPDATE_CHECK_INTERVAL, 0, self.container_updates)
        for service_name, update_data in self.container_update.items().items():
            self.log.info(f"Service {service_name} is being updated to version {update_data['latest_tag']}...")

            # Load authentication params
            username = None
            password = None
            auth = update_data['auth'] or {}
            if auth:
                username = auth.get('username', None)
                password = auth.get('password', None)

            try:
                self.controller.launch(
                    name=service_name,
                    docker_config=DockerConfig(dict(
                        allow_internet_access=True,
                        registry_username=username,
                        registry_password=password,
                        cpu_cores=1,
                        environment=[],
                        image=update_data['image'],
                        ports=[]
                    )),
                    mounts=[],
                    env={
                        "SERVICE_TAG": update_data['latest_tag'],
                        "SERVICE_API_HOST": os.environ.get('SERVICE_API_HOST', "http://al_service_server:5003"),
                        "REGISTER_ONLY": 'true'
                    },
                    network='al_registration',
                    blocking=True
                )

                latest_tag = update_data['latest_tag'].replace('stable', '')

                service_key = f"{service_name}_{latest_tag}"

                if self.datastore.service.get_if_exists(service_key):
                    operations = [(self.datastore.service_delta.UPDATE_SET, 'version', latest_tag)]
                    if self.datastore.service_delta.update(service_name, operations):
                        # Update completed, cleanup
                        self.log.info(f"Service {service_name} update successful!")
                    else:
                        self.log.error(f"Service {service_name} has failed to update because it cannot set "
                                       f"{latest_tag} as the new version. Update procedure cancelled...")
                else:
                    self.log.error(f"Service {service_name} has failed to update because resulting "
                                   f"service key ({service_key}) does not exist. Update procedure cancelled...")
            except Exception as e:
                self.log.error(f"Service {service_name} has failed to update. Update procedure cancelled... [{str(e)}]")

            self.container_update.pop(service_name)

    def container_versions(self):
        """Go through the list of services and check what are the latest tags for it"""
        self.scheduler.enter(CONTAINER_CHECK_INTERVAL, 0, self.container_versions)

        for service in self.datastore.list_all_services(full=True):
            if not service.enabled:
                continue

            image_name, tag_name, auth = get_latest_tag_for_service(service, self.config, self.log)

            # Record the latest tag under the service's update channel name
            self.latest_service_tags.set(service.name,
                                         {'auth': auth, 'image': image_name, service.update_channel: tag_name})

    def try_run(self):
        """Run the scheduler loop until told to stop."""
        # Do an initial call to the main methods, who will then be registered with the scheduler
        self.sync_services()
        self.update_services()
        self.container_versions()
        self.container_updates()
        self.heartbeat()

        # Run as long as we need to
        while self.running:
            delay = self.scheduler.run(False)
            # run(blocking=False) returns None when nothing is scheduled
            time.sleep(0.1 if delay is None else min(delay, 0.1))

    def heartbeat(self):
        """Periodically touch a file on disk.

        Since tasks are run serially, the delay between touches will be the maximum of
        HEARTBEAT_INTERVAL and the longest running task.
        """
        if self.config.logging.heartbeat_file:
            self.scheduler.enter(HEARTBEAT_INTERVAL, 0, self.heartbeat)
            super().heartbeat()

    def update_services(self):
        """Check if we need to update any services.

        Spin off a thread to actually perform any updates. Don't allow multiple threads per service.
        """
        self.scheduler.enter(UPDATE_CHECK_INTERVAL, 0, self.update_services)

        # Check for finished update threads
        self.running_updates = {name: thread for name, thread in self.running_updates.items() if thread.is_alive()}

        # Check if its time to try to update the service
        for service_name, data in self.services.items().items():
            if data['next_update'] <= now_as_iso() and service_name not in self.running_updates:
                self.log.info(f"Time to update {service_name}")
                self.running_updates[service_name] = Thread(
                    target=self.run_update,
                    kwargs=dict(service_name=service_name)
                )
                self.running_updates[service_name].start()

    def run_update(self, service_name):
        """Common setup and tear down for all update types."""
        # noinspection PyBroadException
        try:
            # Check for new update with service specified update method
            service = self.datastore.get_service_with_delta(service_name)
            update_method = service.update_config.method
            update_data = self.services.get(service_name)
            update_hash = None

            try:
                # Actually run the update method
                if update_method == 'run':
                    update_hash = self.do_file_update(
                        service=service,
                        previous_hash=update_data['sha256'],
                        previous_update=update_data['previous_update']
                    )
                elif update_method == 'build':
                    update_hash = self.do_build_update()

                # If we have performed an update, write that data
                if update_hash is not None and update_hash != update_data['sha256']:
                    update_data['sha256'] = update_hash
                    update_data['previous_update'] = now_as_iso()
                else:
                    update_hash = None

            finally:
                # Update the next service update check time, don't update the config_hash,
                # as we don't want to disrupt being re-run if our config has changed during this run
                update_data['next_update'] = now_as_iso(service.update_config.update_interval_seconds)
                self.services.set(service_name, update_data)

            if update_hash:
                self.log.info(f"New update applied for {service_name}. Restarting service.")
                self.controller.restart(service_name=service_name)

        except BaseException:
            self.log.exception("An error occurred while running an update for: " + service_name)

    def do_build_update(self):
        """Update a service by building a new container to run."""
        raise NotImplementedError()

    def do_file_update(self, service, previous_hash, previous_update):
        """Update a service by running a container to get new files.

        The update container receives a config.yaml (sources, previous hash/update time,
        temporary API credentials) and writes its files plus a response.yaml into an output
        directory. That directory is moved into the service's update folder under a
        timestamped name, and only the newest UPDATE_FOLDER_LIMIT folders are kept.

        :return: The 'hash' value from response.yaml, or None if no response was produced.
        """
        temp_directory = tempfile.mkdtemp(dir=self.temporary_directory)
        chmod(temp_directory, 0o777)
        input_directory = os.path.join(temp_directory, 'input_directory')
        output_directory = os.path.join(temp_directory, 'output_directory')
        service_dir = os.path.join(FILE_UPDATE_DIRECTORY, service.name)
        image_variables = defaultdict(str)
        image_variables.update(self.config.services.image_variables)

        try:
            # Use chmod directly to avoid effects of umask
            os.makedirs(input_directory)
            chmod(input_directory, 0o755)
            os.makedirs(output_directory)
            chmod(output_directory, 0o777)

            username = self.ensure_service_account()

            with temporary_api_key(self.datastore, username) as api_key:

                # Write out the parameters we want to pass to the update container
                with open(os.path.join(input_directory, 'config.yaml'), 'w') as fh:
                    yaml.safe_dump({
                        'previous_update': previous_update,
                        'previous_hash': previous_hash,
                        'sources': [x.as_primitives() for x in service.update_config.sources],
                        'api_user': username,
                        'api_key': api_key,
                        'ui_server': UI_SERVER
                    }, fh)

                # Run the update container
                run_options = service.update_config.run_options
                run_options.image = string.Template(run_options.image).safe_substitute(image_variables)
                self.controller.launch(
                    name=service.name,
                    docker_config=run_options,
                    mounts=[
                        {
                            'volume': FILE_UPDATE_VOLUME,
                            'source_path': os.path.relpath(temp_directory, start=FILE_UPDATE_DIRECTORY),
                            'dest_path': '/mount/'
                        },
                    ],
                    env={
                        'UPDATE_CONFIGURATION_PATH': '/mount/input_directory/config.yaml',
                        'UPDATE_OUTPUT_PATH': '/mount/output_directory/'
                    },
                    network=f'service-net-{service.name}',
                    blocking=True,
                )

                # Read out the results from the output container
                results_meta_file = os.path.join(output_directory, 'response.yaml')

                if not os.path.exists(results_meta_file) or not os.path.isfile(results_meta_file):
                    self.log.warning(f"Update produced no output for {service.name}")
                    return None

                with open(results_meta_file) as rf:
                    results_meta = yaml.safe_load(rf)
                update_hash = results_meta.get('hash', None)

                # Erase the results meta file
                os.unlink(results_meta_file)

                # Get a timestamp for now, and switch it to basic format representation of time
                # Still valid iso 8601, and : is sometimes a restricted character
                timestamp = now_as_iso().replace(":", "")

                # FILE_UPDATE_DIRECTORY/{service_name} is the directory mounted to the service,
                # the service sees multiple directories in that directory, each with a timestamp
                destination_dir = os.path.join(service_dir, service.name + '_' + timestamp)
                shutil.move(output_directory, destination_dir)

                # Remove older update files, due to the naming scheme, older ones will sort first lexically
                existing_folders = []
                for folder_name in os.listdir(service_dir):
                    folder_path = os.path.join(service_dir, folder_name)
                    if os.path.isdir(folder_path) and folder_name.startswith(service.name):
                        existing_folders.append(folder_name)
                existing_folders.sort()

                self.log.info(f'There are {len(existing_folders)} update folders for {service.name} in cache.')
                if len(existing_folders) > UPDATE_FOLDER_LIMIT:
                    extra_count = len(existing_folders) - UPDATE_FOLDER_LIMIT
                    self.log.info(f'We will only keep {UPDATE_FOLDER_LIMIT} updates, deleting {extra_count}.')
                    for extra_folder in existing_folders[:extra_count]:
                        # noinspection PyBroadException
                        try:
                            shutil.rmtree(os.path.join(service_dir, extra_folder))
                        except Exception:
                            self.log.exception('Failed to delete update folder')

                return update_hash

        finally:
            # If the working directory is still there for any reason erase it
            shutil.rmtree(temp_directory, ignore_errors=True)

    def ensure_service_account(self):
        """Check that the update service account exists, if it doesn't, create it."""
        import secrets  # CSPRNG for the throwaway account password

        uname = 'update_service_account'

        if self.datastore.user.get_if_exists(uname):
            return uname

        user_data = User({
            "agrees_with_tos": "NOW",
            "classification": "RESTRICTED",
            "name": "Update Account",
            # The password is never used interactively, but it must not be predictable:
            # `random` is not suitable for secrets.
            "password": get_password_hash(''.join(secrets.choice(string.ascii_letters) for _ in range(20))),
            "uname": uname,
            "type": ["signature_importer"]
        })
        self.datastore.user.save(uname, user_data)
        self.datastore.user_settings.save(uname, UserSettings())
        return uname
from assemblyline.remote.datatypes import get_client
from assemblyline.remote.datatypes.events import EventSender
from assemblyline.remote.datatypes.hash import Hash
from assemblyline_core.updater.helper import get_latest_tag_for_service
from assemblyline_ui.api.base import api_login, make_api_response, make_file_response, make_subapi_blueprint
from assemblyline_ui.api.v4.signature import _reset_service_updates
from assemblyline_ui.config import LOGGER, STORAGE, config, CLASSIFICATION as Classification

# Blueprint for the v4 'service' sub-API
SUB_API = 'service'
service_api = make_subapi_blueprint(SUB_API, api_version=4)
service_api._doc = "Manage the different services"

# Redis hash 'service-tags' on the persistent redis server.
# NOTE(review): presumably the same hash the ServiceUpdater fills with the latest
# image/tag/auth per service — confirm both sides point at the same redis.
latest_service_tags = Hash(
    'service-tags',
    get_client(
        host=config.core.redis.persistent.host,
        port=config.core.redis.persistent.port,
        private=False,
    ))

# Redis hash 'container-update' on the persistent redis server.
# NOTE(review): presumably the queue of pending container updates that the
# ServiceUpdater consumes — confirm.
service_update = Hash(
    'container-update',
    get_client(
        host=config.core.redis.persistent.host,
        port=config.core.redis.persistent.port,
        private=False,
    ))

# Event sender for the 'changes.services' channel on the non-persistent redis server
event_sender = EventSender('changes.services',
                           host=config.core.redis.nonpersistent.host,
                           port=config.core.redis.nonpersistent.port)