def test_submission_namespace(datastore, sio):
    submission_queue = CommsQueue('submissions', private=True)
    monitoring = get_random_id()

    ingested = random_model_obj(SubmissionMessage).as_primitives()
    ingested['msg_type'] = "SubmissionIngested"
    received = random_model_obj(SubmissionMessage).as_primitives()
    received['msg_type'] = "SubmissionReceived"
    queued = random_model_obj(SubmissionMessage).as_primitives()
    queued['msg_type'] = "SubmissionQueued"
    started = random_model_obj(SubmissionMessage).as_primitives()
    started['msg_type'] = "SubmissionStarted"

    test_res_array = []

    @sio.on('monitoring', namespace='/submissions')
    def on_monitoring(data):
        # Confirmation that we are waiting for status messages
        test_res_array.append(('on_monitoring', data == monitoring))

    @sio.on('SubmissionIngested', namespace='/submissions')
    def on_submission_ingested(data):
        test_res_array.append(('on_submission_ingested', data == ingested['msg']))

    @sio.on('SubmissionReceived', namespace='/submissions')
    def on_submission_received(data):
        test_res_array.append(('on_submission_received', data == received['msg']))

    @sio.on('SubmissionQueued', namespace='/submissions')
    def on_submission_queued(data):
        test_res_array.append(('on_submission_queued', data == queued['msg']))

    @sio.on('SubmissionStarted', namespace='/submissions')
    def on_submission_started(data):
        test_res_array.append(('on_submission_started', data == started['msg']))

    try:
        sio.emit('monitor', monitoring, namespace='/submissions')
        sio.sleep(1)

        submission_queue.publish(ingested)
        submission_queue.publish(received)
        submission_queue.publish(queued)
        submission_queue.publish(started)

        start_time = time.time()
        while len(test_res_array) < 5 and time.time() - start_time < 5:
            sio.sleep(0.1)

        assert len(test_res_array) == 5
        for test, result in test_res_array:
            if not result:
                pytest.fail(f"{test} failed.")
    finally:
        sio.disconnect()
def test_comms_queue(redis_connection):
    if redis_connection:
        from assemblyline.remote.datatypes.queues.comms import CommsQueue

        def publish_messages(message_list):
            time.sleep(0.1)
            with CommsQueue('test-comms-queue') as cq_p:
                for message in message_list:
                    cq_p.publish(message)

        msg_list = ["bob", 1, {"bob": 1}, [1, 2, 3], None, "Nice!", "stop"]
        t = Thread(target=publish_messages, args=(msg_list,))
        t.start()

        with CommsQueue('test-comms-queue') as cq:
            x = 0
            for msg in cq.listen():
                if msg == "stop":
                    break

                assert msg == msg_list[x]
                x += 1

        t.join()
        assert not t.is_alive()
def test_alert_created(datastore, client):
    alert_queue = CommsQueue('alerts', private=True)

    created = random_model_obj(AlertMessage)
    created.msg_type = "AlertCreated"
    updated = random_model_obj(AlertMessage)
    updated.msg_type = "AlertUpdated"

    test_res_array = []

    def alerter_created_callback(data):
        test_res_array.append(('created', created['msg'] == data))

    def alerter_updated_callback(data):
        test_res_array.append(('updated', updated['msg'] == data))

    def publish_thread():
        time.sleep(1)
        alert_queue.publish(created.as_primitives())
        alert_queue.publish(updated.as_primitives())

    threading.Thread(target=publish_thread).start()
    client.socketio.listen_on_alerts_messages(
        alert_created_callback=alerter_created_callback,
        alert_updated_callback=alerter_updated_callback,
        timeout=2)

    assert len(test_res_array) == 2
    for test, result in test_res_array:
        if not result:
            pytest.fail("{} failed.".format(test))
def save_alert(datastore, counter, logger, alert, psid):
    def create_alert():
        msg_type = "AlertCreated"
        datastore.alert.save(alert['alert_id'], alert)
        logger.info(f"Alert {alert['alert_id']} has been created.")
        counter.increment('created')
        ret_val = 'create'
        return msg_type, ret_val

    if psid:
        try:
            msg_type = "AlertUpdated"
            perform_alert_update(datastore, logger, alert)
            counter.increment('updated')
            ret_val = 'update'
        except AlertMissingError as e:
            logger.info(f"{str(e)}. Creating a new alert [{alert['alert_id']}]...")
            msg_type, ret_val = create_alert()
    else:
        msg_type, ret_val = create_alert()

    msg = AlertMessage({
        "msg": alert,
        "msg_type": msg_type,
        "sender": "alerter"
    })
    CommsQueue('alerts').publish(msg.as_primitives())

    return ret_val
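save_alert ends by publishing an AlertMessage onto the 'alerts' channel, so anything interested in alert lifecycle events only needs to subscribe and branch on msg_type. Below is a minimal sketch of such a standalone listener (separate from the SocketIO namespace shown later); the field names are taken from the messages above, and the snippet is an illustration rather than code from the project.

# Standalone listener for alert lifecycle events published by save_alert().
from assemblyline.remote.datatypes.queues.comms import CommsQueue

with CommsQueue('alerts') as alert_events:
    for event in alert_events.listen():
        alert = event['msg']                       # alert body as primitives
        if event['msg_type'] == 'AlertCreated':
            print('new alert:', alert['alert_id'])
        elif event['msg_type'] == 'AlertUpdated':
            print('updated alert:', alert['alert_id'])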
def test_alert_namespace(datastore, sio):
    alert_queue = CommsQueue('alerts', private=True)
    test_id = get_random_id()

    created = random_model_obj(AlertMessage)
    created.msg_type = "AlertCreated"
    updated = random_model_obj(AlertMessage)
    updated.msg_type = "AlertUpdated"

    test_res_array = []

    @sio.on('monitoring', namespace='/alerts')
    def on_monitoring(data):
        # Confirmation that we are waiting for alerts
        test_res_array.append(('on_monitoring', data == test_id))

    @sio.on('AlertCreated', namespace='/alerts')
    def on_alert_created(data):
        test_res_array.append(('on_alert_created', data == created.as_primitives()['msg']))

    @sio.on('AlertUpdated', namespace='/alerts')
    def on_alert_updated(data):
        test_res_array.append(('on_alert_updated', data == updated.as_primitives()['msg']))

    try:
        sio.emit('alert', test_id, namespace='/alerts')
        sio.sleep(1)

        alert_queue.publish(created.as_primitives())
        alert_queue.publish(updated.as_primitives())

        start_time = time.time()
        while len(test_res_array) < 3 and time.time() - start_time < 5:
            sio.sleep(0.1)

        assert len(test_res_array) == 3
        for test, result in test_res_array:
            if not result:
                pytest.fail(f"{test} failed.")
    finally:
        sio.disconnect()
def try_run(self):
    # If our connection to the metrics database requires a custom ca cert, prepare it
    ca_certs = None
    if self.config.core.metrics.elasticsearch.host_certificates:
        with tempfile.NamedTemporaryFile(delete=False) as ca_certs_file:
            ca_certs = ca_certs_file.name
            ca_certs_file.write(self.config.core.metrics.elasticsearch.host_certificates.encode())

    self.metrics_queue = CommsQueue(METRICS_QUEUE)
    self.es = elasticsearch.Elasticsearch(hosts=self.elastic_hosts,
                                          connection_class=elasticsearch.RequestsHttpConnection,
                                          ca_certs=ca_certs)

    self.scheduler.add_job(self._create_aggregated_metrics, 'interval', seconds=60)
    self.scheduler.start()

    while self.running:
        for msg in self.metrics_queue.listen():
            # APM Transaction start
            if self.apm_client:
                self.apm_client.begin_transaction('metrics')

            m_name = msg.pop('name', None)
            m_type = msg.pop('type', None)
            msg.pop('host', None)
            msg.pop('instance', None)

            self.log.debug(f"Received {m_type.upper()} metrics message")
            if not m_name or not m_type:
                # APM Transaction end
                if self.apm_client:
                    self.apm_client.end_transaction('process_message', 'invalid_message')
                continue

            with self.counters_lock:
                c_key = (m_name, m_type)
                if c_key not in self.counters or m_type in NON_AGGREGATED:
                    self.counters[c_key] = Counter(msg)
                else:
                    self.counters[c_key].update(Counter(msg))

            # APM Transaction end
            if self.apm_client:
                self.apm_client.end_transaction('process_message', 'success')
def test_status_messages(datastore, client):
    status_queue = CommsQueue('status', private=True)
    test_res_array = []

    alerter_hb_msg = random_model_obj(AlerterMessage).as_primitives()
    dispatcher_hb_msg = random_model_obj(DispatcherMessage).as_primitives()
    expiry_hb_msg = random_model_obj(ExpiryMessage).as_primitives()
    ingest_hb_msg = random_model_obj(IngestMessage).as_primitives()
    service_hb_msg = random_model_obj(ServiceMessage).as_primitives()
    service_timing_msg = random_model_obj(ServiceTimingMessage).as_primitives()

    def alerter_callback(data):
        test_res_array.append(('alerter', alerter_hb_msg['msg'] == data))

    def dispatcher_callback(data):
        test_res_array.append(('dispatcher', dispatcher_hb_msg['msg'] == data))

    def expiry_callback(data):
        test_res_array.append(('expiry', expiry_hb_msg['msg'] == data))

    def ingest_callback(data):
        test_res_array.append(('ingest', ingest_hb_msg['msg'] == data))

    def service_callback(data):
        test_res_array.append(('service', service_hb_msg['msg'] == data))

    def service_timing_callback(data):
        test_res_array.append(('service_timing', service_timing_msg['msg'] == data))

    def publish_thread():
        time.sleep(1)
        status_queue.publish(alerter_hb_msg)
        status_queue.publish(dispatcher_hb_msg)
        status_queue.publish(expiry_hb_msg)
        status_queue.publish(ingest_hb_msg)
        status_queue.publish(service_hb_msg)
        status_queue.publish(service_timing_msg)

    threading.Thread(target=publish_thread).start()
    client.socketio.listen_on_status_messages(
        alerter_msg_callback=alerter_callback,
        dispatcher_msg_callback=dispatcher_callback,
        expiry_msg_callback=expiry_callback,
        ingest_msg_callback=ingest_callback,
        service_msg_callback=service_callback,
        service_timing_msg_callback=service_timing_callback,
        timeout=2)

    assert len(test_res_array) == 6
    for test, result in test_res_array:
        if not result:
            pytest.fail("{} failed.".format(test))
def monitor_system_status(self):
    q = CommsQueue('status', private=True)
    try:
        for msg in q.listen():
            if self.stop:
                break

            message = msg['msg']
            msg_type = msg['msg_type']
            self.socketio.emit(msg_type, message, namespace=self.namespace)
            LOGGER.info(f"SocketIO:{self.namespace} - Sending {msg_type} event to all connected users.")
    except Exception:
        LOGGER.exception(f"SocketIO:{self.namespace}")
    finally:
        LOGGER.info(f"SocketIO:{self.namespace} - No more users connected to status monitoring, exiting thread...")
        with self.connections_lock:
            self.background_task = None
def monitor_alerts(self, user_info):
    sid = user_info['sid']
    q = CommsQueue('alerts', private=True)
    try:
        for msg in q.listen():
            if sid not in self.connections:
                break

            alert = msg['msg']
            msg_type = msg['msg_type']

            if classification.is_accessible(user_info['classification'],
                                            alert.get('classification', classification.UNRESTRICTED)):
                self.socketio.emit(msg_type, alert, room=sid, namespace=self.namespace)
                LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - "
                            f"Sending {msg_type} event for alert matching ID: {alert['alert_id']}")

                if AUDIT:
                    AUDIT_LOG.info(f"{user_info['uname']} [{user_info['classification']}]"
                                   f" :: AlertMonitoringNamespace.get_alert(alert_id={alert['alert_id']})")
    except Exception:
        LOGGER.exception(f"SocketIO:{self.namespace} - {user_info['display']}")
    finally:
        LOGGER.info(f"SocketIO:{self.namespace} - {user_info['display']} - Connection to client was terminated")
def test_submission_ingested(datastore, client):
    submission_queue = CommsQueue('submissions', private=True)
    test_res_array = []

    completed = random_model_obj(SubmissionMessage).as_primitives()
    completed['msg_type'] = "SubmissionCompleted"
    ingested = random_model_obj(SubmissionMessage).as_primitives()
    ingested['msg_type'] = "SubmissionIngested"
    received = random_model_obj(SubmissionMessage).as_primitives()
    received['msg_type'] = "SubmissionReceived"
    started = random_model_obj(SubmissionMessage).as_primitives()
    started['msg_type'] = "SubmissionStarted"

    def completed_callback(data):
        test_res_array.append(('completed', completed['msg'] == data))

    def ingested_callback(data):
        test_res_array.append(('ingested', ingested['msg'] == data))

    def received_callback(data):
        test_res_array.append(('received', received['msg'] == data))

    def started_callback(data):
        test_res_array.append(('started', started['msg'] == data))

    def publish_thread():
        time.sleep(1)
        submission_queue.publish(completed)
        submission_queue.publish(ingested)
        submission_queue.publish(received)
        submission_queue.publish(started)

    threading.Thread(target=publish_thread).start()
    client.socketio.listen_on_submissions(
        completed_callback=completed_callback,
        ingested_callback=ingested_callback,
        received_callback=received_callback,
        started_callback=started_callback,
        timeout=2)

    assert len(test_res_array) == 4
    for test, result in test_res_array:
        if not result:
            pytest.fail("{} failed.".format(test))
def save_alert(datastore, counter, logger, alert, psid):
    if psid:
        msg_type = "AlertUpdated"
        perform_alert_update(datastore, logger, alert)
        counter.increment('updated')
        ret_val = 'update'
    else:
        msg_type = "AlertCreated"
        datastore.alert.save(alert['alert_id'], alert)
        logger.info(f"Alert {alert['alert_id']} has been created.")
        counter.increment('created')
        ret_val = 'create'

    msg = AlertMessage({
        "msg": alert,
        "msg_type": msg_type,
        "sender": "alerter"
    })
    CommsQueue('alerts').publish(msg.as_primitives())

    return ret_val
def _test_message_through_queue(queue_name, test_message, redis):
    t = Thread(target=publish_message, args=(queue_name, test_message, redis))

    try:
        t.start()

        with CommsQueue(queue_name) as cq:
            for msg in cq.listen():
                loader_path = msg.get('msg_loader', None)
                if loader_path is None:
                    raise ValueError("Message does not have a message loader class path.")

                msg_obj = load_module_by_path(loader_path)
                obj = msg_obj(msg)
                assert obj == test_message

                break
    finally:
        t.join()
        assert not t.is_alive()
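The helper above relies on every published message carrying a msg_loader field: a dotted path to the model class that can rebuild a typed object from the raw dict. A rough sketch of that round trip follows; the import paths and the presence of msg_loader in the as_primitives() output are assumptions based on how the surrounding tests use these names, not verified against the library.

# Hypothetical msg_loader round trip, mirroring the convention used in the test above.
from assemblyline.common.importing import load_module_by_path   # assumed import path
from assemblyline.odm.messages.alerter_heartbeat import AlerterMessage  # assumed import path
from assemblyline.odm.randomizer import random_model_obj        # assumed import path

original = random_model_obj(AlerterMessage)
raw = original.as_primitives()           # plain dict, assumed to include a 'msg_loader' dotted path

loader_path = raw.get('msg_loader')      # dotted path pointing back at the message model class
model_class = load_module_by_path(loader_path)  # resolve the dotted path to the class object
rebuilt = model_class(raw)               # re-hydrate the typed message from the raw dict

assert rebuilt == original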
class HeartbeatFormatter(object): def __init__(self, sender, log, config=None, redis=None): self.sender = sender self.log = log self.config = config or forge.get_config() self.datastore = forge.get_datastore(self.config) self.redis = redis or get_client( host=self.config.core.redis.nonpersistent.host, port=self.config.core.redis.nonpersistent.port, private=False, ) self.redis_persist = get_client( host=self.config.core.redis.persistent.host, port=self.config.core.redis.persistent.port, private=False, ) self.status_queue = CommsQueue(STATUS_QUEUE, self.redis) self.dispatch_active_hash = Hash(DISPATCH_TASK_HASH, self.redis_persist) self.dispatcher_submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis) self.ingest_scanning = Hash('m-scanning-table', self.redis_persist) self.ingest_unique_queue = PriorityQueue('m-unique', self.redis_persist) self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist) self.ingest_complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis) self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist) constants = forge.get_constants(self.config) self.c_rng = constants.PRIORITY_RANGES['critical'] self.h_rng = constants.PRIORITY_RANGES['high'] self.m_rng = constants.PRIORITY_RANGES['medium'] self.l_rng = constants.PRIORITY_RANGES['low'] self.c_s_at = self.config.core.ingester.sampling_at['critical'] self.h_s_at = self.config.core.ingester.sampling_at['high'] self.m_s_at = self.config.core.ingester.sampling_at['medium'] self.l_s_at = self.config.core.ingester.sampling_at['low'] self.to_expire = {k: 0 for k in metrics.EXPIRY_METRICS} if self.config.core.expiry.batch_delete: self.delete_query = f"expiry_ts:[* TO {self.datastore.ds.now}-{self.config.core.expiry.delay}" \ f"{self.datastore.ds.hour}/DAY]" else: self.delete_query = f"expiry_ts:[* TO {self.datastore.ds.now}-{self.config.core.expiry.delay}" \ f"{self.datastore.ds.hour}]" self.scheduler = BackgroundScheduler(daemon=True) self.scheduler.add_job( self._reload_expiry_queues, 'interval', seconds=self.config.core.metrics.export_interval * 4) self.scheduler.start() def _reload_expiry_queues(self): try: self.log.info("Refreshing expiry queues...") for collection_name in metrics.EXPIRY_METRICS: try: collection = getattr(self.datastore, collection_name) self.to_expire[collection_name] = collection.search( self.delete_query, rows=0, fl='id', track_total_hits="true")['total'] except SearchException: self.to_expire[collection_name] = 0 except Exception: self.log.exception( "Unknown exception occurred while reloading expiry queues:") def send_heartbeat(self, m_type, m_name, m_data, instances): if m_type == "dispatcher": try: instances = sorted(Dispatcher.all_instances( self.redis_persist)) inflight = { _i: Dispatcher.instance_assignment_size( self.redis_persist, _i) for _i in instances } queues = { _i: Dispatcher.all_queue_lengths(self.redis, _i) for _i in instances } msg = { "sender": self.sender, "msg": { "inflight": { "max": self.config.core.dispatcher.max_inflight, "outstanding": self.dispatch_active_hash.length(), "per_instance": [inflight[_i] for _i in instances] }, "instances": len(instances), "metrics": m_data, "queues": { "ingest": self.dispatcher_submission_queue.length(), "start": [queues[_i]['start'] for _i in instances], "result": [queues[_i]['result'] for _i in instances], "command": [queues[_i]['command'] for _i in instances] }, "component": m_name, } } self.status_queue.publish( DispatcherMessage(msg).as_primitives()) self.log.info(f"Sent dispatcher heartbeat: {msg['msg']}") except 
Exception: self.log.exception( "An exception occurred while generating DispatcherMessage") elif m_type == "ingester": try: c_q_len = self.ingest_unique_queue.count(*self.c_rng) h_q_len = self.ingest_unique_queue.count(*self.h_rng) m_q_len = self.ingest_unique_queue.count(*self.m_rng) l_q_len = self.ingest_unique_queue.count(*self.l_rng) msg = { "sender": self.sender, "msg": { "instances": instances, "metrics": m_data, "processing": { "inflight": self.ingest_scanning.length() }, "processing_chance": { "critical": 1 - drop_chance(c_q_len, self.c_s_at), "high": 1 - drop_chance(h_q_len, self.h_s_at), "low": 1 - drop_chance(l_q_len, self.l_s_at), "medium": 1 - drop_chance(m_q_len, self.m_s_at) }, "queues": { "critical": c_q_len, "high": h_q_len, "ingest": self.ingest_queue.length(), "complete": self.ingest_complete_queue.length(), "low": l_q_len, "medium": m_q_len } } } self.status_queue.publish(IngestMessage(msg).as_primitives()) self.log.info(f"Sent ingester heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating IngestMessage") elif m_type == "alerter": try: msg = { "sender": self.sender, "msg": { "instances": instances, "metrics": m_data, "queues": { "alert": self.alert_queue.length() } } } self.status_queue.publish(AlerterMessage(msg).as_primitives()) self.log.info(f"Sent alerter heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating AlerterMessage") elif m_type == "expiry": try: msg = { "sender": self.sender, "msg": { "instances": instances, "metrics": m_data, "queues": self.to_expire } } self.status_queue.publish(ExpiryMessage(msg).as_primitives()) self.log.info(f"Sent expiry heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating ExpiryMessage") elif m_type == "archive": try: msg = { "sender": self.sender, "msg": { "instances": instances, "metrics": m_data } } self.status_queue.publish(ArchiveMessage(msg).as_primitives()) self.log.info(f"Sent archive heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating ArchiveMessage") elif m_type == "scaler": try: msg = { "sender": self.sender, "msg": { "instances": instances, "metrics": m_data, } } self.status_queue.publish(ScalerMessage(msg).as_primitives()) self.log.info(f"Sent scaler heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating WatcherMessage") elif m_type == "scaler_status": try: msg = { "sender": self.sender, "msg": { "service_name": m_name, "metrics": m_data, } } self.status_queue.publish( ScalerStatusMessage(msg).as_primitives()) self.log.info(f"Sent scaler status heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating WatcherMessage") elif m_type == "service": try: busy, idle = get_working_and_idle(self.redis, m_name) msg = { "sender": self.sender, "msg": { "instances": len(busy) + len(idle), "metrics": m_data, "activity": { 'busy': len(busy), 'idle': len(idle) }, "queue": get_service_queue(m_name, self.redis).length(), "service_name": m_name } } self.status_queue.publish(ServiceMessage(msg).as_primitives()) self.log.info(f"Sent service heartbeat: {msg['msg']}") except Exception: self.log.exception( "An exception occurred while generating ServiceMessage") else: self.log.warning( f"Skipping unknown counter: {m_name} [{m_type}] ==> {m_data}")
def __init__(self, datastore, logger, classification=None, redis=None, persistent_redis=None, metrics_name='ingester'): self.datastore = datastore self.log = logger # Cache the user groups self.cache_lock = threading.RLock( ) # TODO are middle man instances single threaded now? self._user_groups = {} self._user_groups_reset = time.time() // HOUR_IN_SECONDS self.cache = {} self.notification_queues = {} self.whitelisted = {} self.whitelisted_lock = threading.RLock() # Create a config cache that will refresh config values periodically self.config = forge.CachedObject(forge.get_config) # Module path parameters are fixed at start time. Changing these involves a restart self.is_low_priority = load_module_by_path( self.config.core.ingester.is_low_priority) self.get_whitelist_verdict = load_module_by_path( self.config.core.ingester.get_whitelist_verdict) self.whitelist = load_module_by_path( self.config.core.ingester.whitelist) # Constants are loaded based on a non-constant path, so has to be done at init rather than load constants = forge.get_constants(self.config) self.priority_value = constants.PRIORITIES self.priority_range = constants.PRIORITY_RANGES self.threshold_value = constants.PRIORITY_THRESHOLDS # Connect to the redis servers self.redis = redis or get_client( host=self.config.core.redis.nonpersistent.host, port=self.config.core.redis.nonpersistent.port, private=False, ) self.persistent_redis = persistent_redis or get_client( host=self.config.core.redis.persistent.host, port=self.config.core.redis.persistent.port, private=False, ) # Classification engine self.ce = classification or forge.get_classification() # Metrics gathering factory self.counter = MetricsFactory(metrics_type='ingester', schema=Metrics, redis=self.redis, config=self.config, name=metrics_name) # State. The submissions in progress are stored in Redis in order to # persist this state and recover in case we crash. self.scanning = Hash('m-scanning-table', self.persistent_redis) # Input. The dispatcher creates a record when any submission completes. self.complete_queue = NamedQueue(_completeq_name, self.redis) # Internal. Dropped entries are placed on this queue. # self.drop_queue = NamedQueue('m-drop', self.persistent_redis) # Input. An external process places submission requests on this queue. self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.persistent_redis) # Output. Duplicate our input traffic into this queue so it may be cloned by other systems self.traffic_queue = CommsQueue('submissions', self.redis) # Internal. Unique requests are placed in and processed from this queue. self.unique_queue = PriorityQueue('m-unique', self.persistent_redis) # Internal, delay queue for retrying self.retry_queue = PriorityQueue('m-retry', self.persistent_redis) # Internal, timeout watch queue self.timeout_queue = PriorityQueue('m-timeout', self.redis) # Internal, queue for processing duplicates # When a duplicate file is detected (same cache key => same file, and same # submission parameters) the file won't be ingested normally, but instead a reference # will be written to a duplicate queue. Whenever a file is finished, in the complete # method, not only is the original ingestion finalized, but all entries in the duplicate queue # are finalized as well. This has the effect that all concurrent ingestion of the same file # are 'merged' into a single submission to the system. self.duplicate_queue = MultiQueue(self.persistent_redis) # Output. 
submissions that should have alerts generated self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis) # Utility object to help submit tasks to dispatching self.submit_client = SubmissionClient(datastore=self.datastore, redis=self.redis)
def get_metrics_sink(redis=None):
    from assemblyline.remote.datatypes.queues.comms import CommsQueue
    return CommsQueue('assemblyline_metrics', host=redis)
#!/usr/bin/env python
import sys

from assemblyline.remote.datatypes.queues.comms import CommsQueue
from pprint import pprint

if __name__ == "__main__":
    queue_name = None
    if len(sys.argv) > 1:
        queue_name = sys.argv[1]

    if queue_name is None:
        print("\nERROR: You must specify a queue name.\n\npubsub_reader.py [queue_name]")
        exit(1)

    print(f"Listening for messages on '{queue_name}' queue.")
    q = CommsQueue(queue_name)

    try:
        while True:
            for msg in q.listen():
                pprint(msg)
    except KeyboardInterrupt:
        print('Exiting')
    finally:
        q.close()
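For the reader to print anything, another process has to publish on the same channel while the reader is subscribed, since Redis pub/sub does not replay missed messages. A minimal sketch of such a publisher, reusing the 'test-comms-queue' channel name from the tests above purely as an example:

# Publish a couple of messages that a running `pubsub_reader.py test-comms-queue` would print.
from assemblyline.remote.datatypes.queues.comms import CommsQueue

with CommsQueue('test-comms-queue') as cq:
    cq.publish({"hello": "world"})   # dicts and plain strings both work, as in test_comms_queue
    cq.publish("stop")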
def publish_message(queue_name, test_message, redis):
    time.sleep(0.1)
    with CommsQueue(queue_name, redis) as cq:
        cq.publish(test_message.as_primitives())
class HeartbeatManager(ServerBase): def __init__(self, config=None): super().__init__('assemblyline.heartbeat_manager') self.config = config or forge.get_config() self.datastore = forge.get_datastore() self.metrics_queue = CommsQueue(METRICS_QUEUE) self.scheduler = BackgroundScheduler(daemon=True) self.hm = HeartbeatFormatter("heartbeat_manager", self.log, config=self.config) self.counters_lock = Lock() self.counters = {} self.rolling_window = {} self.window_ttl = {} self.ttl = self.config.core.metrics.export_interval * 2 self.window_size = int(60 / self.config.core.metrics.export_interval) if self.window_size != 60 / self.config.core.metrics.export_interval: self.log.warning( "Cannot calculate a proper window size for reporting heartbeats. " "Metrics reported during hearbeat will be wrong.") if self.config.core.metrics.apm_server.server_url is not None: self.log.info( f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}" ) elasticapm.instrument() self.apm_client = elasticapm.Client( server_url=self.config.core.metrics.apm_server.server_url, service_name="heartbeat_manager") else: self.apm_client = None def try_run(self): self.scheduler.add_job( self._export_hearbeats, 'interval', seconds=self.config.core.metrics.export_interval) self.scheduler.start() while self.running: for msg in self.metrics_queue.listen(): # APM Transaction start if self.apm_client: self.apm_client.begin_transaction('heartbeat') m_name = msg.pop('name', None) m_type = msg.pop('type', None) m_host = msg.pop('host', None) msg.pop('instance', None) self.log.debug(f"Received {m_type.upper()} metrics message") if not m_name or not m_type or not m_host: # APM Transaction end if self.apm_client: self.apm_client.end_transaction( 'process_message', 'invalid_message') continue with self.counters_lock: c_key = (m_name, m_type, m_host) if c_key not in self.counters or m_type in NON_AGGREGATED: self.counters[c_key] = Counter(msg) else: non_agg_values = {} if m_type in NON_AGGREGATED_COUNTERS: non_agg_values = { k: v for k, v in msg.items() if k in NON_AGGREGATED_COUNTERS[m_type] } self.counters[c_key].update(Counter(msg)) for k, v in non_agg_values.items(): self.counters[c_key][k] = v # APM Transaction end if self.apm_client: self.apm_client.end_transaction('process_message', 'success') def _export_hearbeats(self): try: self.heartbeat() self.log.info("Expiring unused counters...") # APM Transaction start if self.apm_client: self.apm_client.begin_transaction('heartbeat') c_time = time.time() for k in list(self.window_ttl.keys()): if self.window_ttl.get(k, c_time) < c_time: c_name, c_type, c_host = k self.log.info( f"Counter {c_name} [{c_type}] for host {c_host} is expired" ) del self.window_ttl[k] del self.rolling_window[k] self.log.info("Saving current counters to rolling window ...") with self.counters_lock: counter_copy, self.counters = self.counters, {} for w_key, counter in counter_copy.items(): _, m_type, _ = w_key if w_key not in self.rolling_window or m_type in NON_AGGREGATED: self.rolling_window[w_key] = [counter] else: self.rolling_window[w_key].append(counter) self.rolling_window[w_key] = self.rolling_window[w_key][ -self.window_size:] self.window_ttl[w_key] = time.time() + self.ttl self.log.info("Compiling service list...") aggregated_counters = {} for service in [ s['name'] for s in self.datastore.list_all_services(as_obj=False) if s['enabled'] ]: data = { 'cache_hit': 0, 'cache_miss': 0, 'cache_skipped': 0, 'execute': 0, 'fail_recoverable': 0, 'fail_nonrecoverable': 0, 'scored': 0, 
'not_scored': 0, 'instances': 0 } aggregated_counters[(service, 'service')] = Counter(data) self.log.info("Aggregating heartbeat data...") for component_parts, counters_list in self.rolling_window.items(): c_name, c_type, c_host = component_parts # Expiring data outside of the window counters_list = counters_list[-self.window_size:] key = (c_name, c_type) if key not in aggregated_counters: aggregated_counters[key] = Counter() aggregated_counters[key]['instances'] += 1 for c in counters_list: aggregated_counters[key].update(c) self.log.info("Generating heartbeats...") for aggregated_parts, counter in aggregated_counters.items(): agg_c_name, agg_c_type = aggregated_parts with elasticapm.capture_span(name=f"{agg_c_type}.{agg_c_name}", span_type="send_heartbeat"): metrics_data = {} for key, value in counter.items(): # Skip counts, they will be paired with a time entry and we only want to count it once if key.endswith('.c'): continue # We have an entry that is a timer, should also have a .c count elif key.endswith('.t'): name = key.rstrip('.t') metrics_data[name] = value / max( counter.get(name + ".c", 1), 1) metrics_data[name + "_count"] = counter.get( name + ".c", 0) # Plain old metric, no modifications needed else: metrics_data[key] = value agg_c_instances = metrics_data.pop('instances', 1) metrics_data.pop('instances_count', None) self.hm.send_heartbeat(agg_c_type, agg_c_name, metrics_data, agg_c_instances) # APM Transaction end if self.apm_client: self.apm_client.end_transaction('send_heartbeats', 'success') except Exception: self.log.exception( "Unknown exception occurred during heartbeat creation:")
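The window bookkeeping above only divides evenly when the export interval divides 60, which is what the warning in __init__ is about. A worked example with a hypothetical export_interval of 6 seconds:

# Hypothetical numbers only; the real value comes from config.core.metrics.export_interval.
export_interval = 6                           # seconds between metric exports
ttl = export_interval * 2                     # 12s: how long an idle counter survives before expiry
window_size = int(60 / export_interval)       # 10 snapshots, i.e. one minute of rolling history

assert window_size == 60 / export_interval    # no truncation, so per-minute heartbeat rates stay accurate
# An interval of 7 would truncate to int(60 / 7) == 8 snapshots (56 seconds), triggering the warning above.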
class MetricsServer(ServerBase): """ There can only be one of these type of metrics server running because it runs of a pubsub queue. """ def __init__(self, config=None): super().__init__('assemblyline.metrics_aggregator', shutdown_timeout=65) self.config = config or forge.get_config() self.elastic_hosts = self.config.core.metrics.elasticsearch.hosts self.is_datastream = False if not self.elastic_hosts: self.log.error( "No elasticsearch cluster defined to store metrics. All gathered stats will be ignored..." ) sys.exit(1) self.scheduler = BackgroundScheduler(daemon=True) self.metrics_queue = None self.es = None self.counters_lock = Lock() self.counters = {} if self.config.core.metrics.apm_server.server_url is not None: self.log.info( f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}" ) elasticapm.instrument() self.apm_client = elasticapm.Client( server_url=self.config.core.metrics.apm_server.server_url, service_name="metrics_aggregator") else: self.apm_client = None def try_run(self): # If our connection to the metrics database requires a custom ca cert, prepare it ca_certs = None if self.config.core.metrics.elasticsearch.host_certificates: with tempfile.NamedTemporaryFile(delete=False) as ca_certs_file: ca_certs = ca_certs_file.name ca_certs_file.write(self.config.core.metrics.elasticsearch. host_certificates.encode()) self.metrics_queue = CommsQueue(METRICS_QUEUE) self.es = elasticsearch.Elasticsearch( hosts=self.elastic_hosts, connection_class=elasticsearch.RequestsHttpConnection, ca_certs=ca_certs) # Determine if ES will support data streams (>= 7.9) self.is_datastream = version.parse( self.es.info()['version']['number']) >= version.parse("7.9") self.scheduler.add_job(self._create_aggregated_metrics, 'interval', seconds=60) self.scheduler.start() while self.running: for msg in self.metrics_queue.listen(): # APM Transaction start if self.apm_client: self.apm_client.begin_transaction('metrics') m_name = msg.pop('name', None) m_type = msg.pop('type', None) msg.pop('host', None) msg.pop('instance', None) self.log.debug(f"Received {m_type.upper()} metrics message") if not m_name or not m_type: # APM Transaction end if self.apm_client: self.apm_client.end_transaction( 'process_message', 'invalid_message') continue with self.counters_lock: c_key = (m_name, m_type) if c_key not in self.counters or m_type in NON_AGGREGATED: self.counters[c_key] = Counter(msg) else: non_agg_values = {} if m_type in NON_AGGREGATED_COUNTERS: non_agg_values = { k: v for k, v in msg.items() if k in NON_AGGREGATED_COUNTERS[m_type] } self.counters[c_key].update(Counter(msg)) for k, v in non_agg_values.items(): self.counters[c_key][k] = v # APM Transaction end if self.apm_client: self.apm_client.end_transaction('process_message', 'success') def _create_aggregated_metrics(self): self.log.info("Copying counters ...") # APM Transaction start if self.apm_client: self.apm_client.begin_transaction('metrics') with self.counters_lock: counter_copy, self.counters = self.counters, {} self.log.info("Aggregating metrics ...") timestamp = now_as_iso() for component, counts in counter_copy.items(): component_name, component_type = component output_metrics = {'name': component_name, 'type': component_type} for key, value in counts.items(): # Skip counts, they will be paired with a time entry and we only want to count it once if key.endswith('.c'): continue # We have an entry that is a timer, should also have a .c count elif key.endswith('.t'): name = key.rstrip('.t') output_metrics[name] = 
counts[key] / counts.get( name + ".c", 1) output_metrics[name + "_count"] = counts.get( name + ".c", 0) # Plain old metric, no modifications needed else: output_metrics[key] = value ensure_indexes(self.log, self.es, self.config.core.metrics.elasticsearch, [component_type], datastream_enabled=self.is_datastream) index = f"al_metrics_{component_type}" # Were data streams created for the index specified? try: if self.es.indices.get_index_template(name=f"{index}_ds"): output_metrics['@timestamp'] = timestamp index = f"{index}_ds" except elasticsearch.exceptions.TransportError: pass output_metrics['timestamp'] = timestamp output_metrics = cleanup_metrics(output_metrics) self.log.info(output_metrics) with_retries(self.log, self.es.index, index=index, body=output_metrics) self.log.info("Metrics aggregated. Waiting for next run...") # APM Transaction end if self.apm_client: self.apm_client.end_transaction('aggregate_metrics', 'success')
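The aggregator pops name, type, host and instance off each message and treats every remaining key as a counter to sum, pairing keys ending in .t with their .c counts to compute averages. The sketch below shows a hand-rolled producer matching that layout; the channel name and the example field values are assumptions inferred from this consumer code, not from the MetricsFactory implementation.

# Hand-rolled metrics message shaped the way the aggregator expects to consume it.
import socket
from assemblyline.remote.datatypes.queues.comms import CommsQueue

metrics_queue = CommsQueue('assemblyline_metrics')   # METRICS_QUEUE is assumed to map to this channel
metrics_queue.publish({
    'name': 'ingester',              # popped as m_name
    'type': 'ingester',              # popped as m_type
    'host': socket.gethostname(),    # popped; ignored here, used as part of the key by the heartbeat manager
    'instance': 0,                   # popped and ignored
    'submissions_ingested': 3,       # plain counter, summed over the 60s aggregation window
    'execute.t': 1.25,               # timer total, divided by its paired count below
    'execute.c': 5,                  # timer count
})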
def get_submission_traffic_channel():
    return CommsQueue('submissions',
                      host=config.core.redis.nonpersistent.host,
                      port=config.core.redis.nonpersistent.port)
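Downstream systems can clone the ingester's submission traffic by subscribing to this channel and filtering on msg_type. A minimal sketch of such a consumer; the message shape ({'msg', 'msg_type', 'sender'}) is taken from the SubmissionMessage publications shown elsewhere in this section, and the 'sid' field is assumed from the submission model.

# Minimal consumer of the 'submissions' traffic channel, keeping only completed submissions.
channel = get_submission_traffic_channel()
try:
    for message in channel.listen():
        if message.get('msg_type') == 'SubmissionCompleted':
            submission = message['msg']                  # primitives of the submission body
            print("Completed submission:", submission.get('sid'))
except KeyboardInterrupt:
    pass
finally:
    channel.close()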
class Ingester(ThreadedCoreBase): def __init__(self, datastore=None, logger=None, classification=None, redis=None, persistent_redis=None, metrics_name='ingester', config=None): super().__init__('assemblyline.ingester', logger, redis=redis, redis_persist=persistent_redis, datastore=datastore, config=config) # Cache the user groups self.cache_lock = threading.RLock() self._user_groups = {} self._user_groups_reset = time.time() // HOUR_IN_SECONDS self.cache = {} self.notification_queues = {} self.whitelisted = {} self.whitelisted_lock = threading.RLock() # Module path parameters are fixed at start time. Changing these involves a restart self.is_low_priority = load_module_by_path( self.config.core.ingester.is_low_priority) self.get_whitelist_verdict = load_module_by_path( self.config.core.ingester.get_whitelist_verdict) self.whitelist = load_module_by_path( self.config.core.ingester.whitelist) # Constants are loaded based on a non-constant path, so has to be done at init rather than load constants = forge.get_constants(self.config) self.priority_value: dict[str, int] = constants.PRIORITIES self.priority_range: dict[str, Tuple[int, int]] = constants.PRIORITY_RANGES self.threshold_value: dict[str, int] = constants.PRIORITY_THRESHOLDS # Classification engine self.ce = classification or forge.get_classification() # Metrics gathering factory self.counter = MetricsFactory(metrics_type='ingester', schema=Metrics, redis=self.redis, config=self.config, name=metrics_name) # State. The submissions in progress are stored in Redis in order to # persist this state and recover in case we crash. self.scanning = Hash('m-scanning-table', self.redis_persist) # Input. The dispatcher creates a record when any submission completes. self.complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis) # Input. An external process places submission requests on this queue. self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist) # Output. Duplicate our input traffic into this queue so it may be cloned by other systems self.traffic_queue = CommsQueue('submissions', self.redis) # Internal. Unique requests are placed in and processed from this queue. self.unique_queue = PriorityQueue('m-unique', self.redis_persist) # Internal, delay queue for retrying self.retry_queue = PriorityQueue('m-retry', self.redis_persist) # Internal, timeout watch queue self.timeout_queue: PriorityQueue[str] = PriorityQueue( 'm-timeout', self.redis) # Internal, queue for processing duplicates # When a duplicate file is detected (same cache key => same file, and same # submission parameters) the file won't be ingested normally, but instead a reference # will be written to a duplicate queue. Whenever a file is finished, in the complete # method, not only is the original ingestion finalized, but all entries in the duplicate queue # are finalized as well. This has the effect that all concurrent ingestion of the same file # are 'merged' into a single submission to the system. self.duplicate_queue = MultiQueue(self.redis_persist) # Output. 
submissions that should have alerts generated self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist) # Utility object to help submit tasks to dispatching self.submit_client = SubmissionClient(datastore=self.datastore, redis=self.redis) if self.config.core.metrics.apm_server.server_url is not None: self.log.info( f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}" ) elasticapm.instrument() self.apm_client = elasticapm.Client( server_url=self.config.core.metrics.apm_server.server_url, service_name="ingester") else: self.apm_client = None def try_run(self): threads_to_maintain = { 'Retries': self.handle_retries, 'Timeouts': self.handle_timeouts } threads_to_maintain.update({ f'Complete_{n}': self.handle_complete for n in range(COMPLETE_THREADS) }) threads_to_maintain.update( {f'Ingest_{n}': self.handle_ingest for n in range(INGEST_THREADS)}) threads_to_maintain.update( {f'Submit_{n}': self.handle_submit for n in range(SUBMIT_THREADS)}) self.maintain_threads(threads_to_maintain) def handle_ingest(self): cpu_mark = time.process_time() time_mark = time.time() # Move from ingest to unique and waiting queues. # While there are entries in the ingest queue we consume chunk_size # entries at a time and move unique entries to uniqueq / queued and # duplicates to their own queues / waiting. while self.running: self.counter.increment_execution_time( 'cpu_seconds', time.process_time() - cpu_mark) self.counter.increment_execution_time('busy_seconds', time.time() - time_mark) message = self.ingest_queue.pop(timeout=1) cpu_mark = time.process_time() time_mark = time.time() if not message: continue # Start of ingest message if self.apm_client: self.apm_client.begin_transaction('ingest_msg') try: if 'submission' in message: # A retried task task = IngestTask(message) else: # A new submission sub = MessageSubmission(message) task = IngestTask(dict( submission=sub, ingest_id=sub.sid, )) task.submission.sid = None # Reset to new random uuid # Write all input to the traffic queue self.traffic_queue.publish( SubmissionMessage({ 'msg': sub, 'msg_type': 'SubmissionIngested', 'sender': 'ingester', }).as_primitives()) except (ValueError, TypeError) as error: self.counter.increment('error') self.log.exception( f"Dropped ingest submission {message} because {str(error)}" ) # End of ingest message (value_error) if self.apm_client: self.apm_client.end_transaction('ingest_input', 'value_error') continue self.ingest(task) # End of ingest message (success) if self.apm_client: self.apm_client.end_transaction('ingest_input', 'success') def handle_submit(self): time_mark, cpu_mark = time.time(), time.process_time() while self.running: # noinspection PyBroadException try: self.counter.increment_execution_time( 'cpu_seconds', time.process_time() - cpu_mark) self.counter.increment_execution_time('busy_seconds', time.time() - time_mark) # Check if there is room for more submissions length = self.scanning.length() if length >= self.config.core.ingester.max_inflight: self.sleep(0.1) time_mark, cpu_mark = time.time(), time.process_time() continue raw = self.unique_queue.blocking_pop(timeout=3) time_mark, cpu_mark = time.time(), time.process_time() if not raw: continue # Start of ingest message if self.apm_client: self.apm_client.begin_transaction('ingest_msg') task = IngestTask(raw) # Check if we need to drop a file for capacity reasons, but only if the # number of files in flight is alreay over 80% if length >= self.config.core.ingester.max_inflight * 0.8 and self.drop( task): # End of 
ingest message (dropped) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'dropped') continue if self.is_whitelisted(task): # End of ingest message (whitelisted) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'whitelisted') continue # Check if this file has been previously processed. pprevious, previous, score, scan_key = None, None, None, None if not task.submission.params.ignore_cache: pprevious, previous, score, scan_key = self.check(task) else: scan_key = self.stamp_filescore_key(task) # If it HAS been previously processed, we are dealing with a resubmission # finalize will decide what to do, and put the task back in the queue # rewritten properly if we are going to run it again if previous: if not task.submission.params.services.resubmit and not pprevious: self.log.warning( f"No psid for what looks like a resubmission of " f"{task.submission.files[0].sha256}: {scan_key}") self.finalize(pprevious, previous, score, task) # End of ingest message (finalized) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'finalized') continue # We have decided this file is worth processing # Add the task to the scanning table, this is atomic across all submit # workers, so if it fails, someone beat us to the punch, record the file # as a duplicate then. if not self.scanning.add(scan_key, task.as_primitives()): self.log.debug('Duplicate %s', task.submission.files[0].sha256) self.counter.increment('duplicates') self.duplicate_queue.push(_dup_prefix + scan_key, task.as_primitives()) # End of ingest message (duplicate) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'duplicate') continue # We have managed to add the task to the scan table, so now we go # ahead with the submission process try: self.submit(task) # End of ingest message (submitted) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'submitted') continue except Exception as _ex: # For some reason (contained in `ex`) we have failed the submission # The rest of this function is error handling/recovery ex = _ex # traceback = _ex.__traceback__ self.counter.increment('error') should_retry = True if isinstance(ex, CorruptedFileStoreException): self.log.error( "Submission for file '%s' failed due to corrupted " "filestore: %s" % (task.sha256, str(ex))) should_retry = False elif isinstance(ex, DataStoreException): trace = exceptions.get_stacktrace_info(ex) self.log.error("Submission for file '%s' failed due to " "data store error:\n%s" % (task.sha256, trace)) elif not isinstance(ex, FileStoreException): trace = exceptions.get_stacktrace_info(ex) self.log.error("Submission for file '%s' failed: %s" % (task.sha256, trace)) task = IngestTask(self.scanning.pop(scan_key)) if not task: self.log.error('No scanning entry for for %s', task.sha256) # End of ingest message (no_scan_entry) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'no_scan_entry') continue if not should_retry: # End of ingest message (cannot_retry) if self.apm_client: self.apm_client.end_transaction( 'ingest_submit', 'cannot_retry') continue self.retry(task, scan_key, ex) # End of ingest message (retry) if self.apm_client: self.apm_client.end_transaction('ingest_submit', 'retried') except Exception: self.log.exception("Unexpected error") # End of ingest message (exception) if self.apm_client: self.apm_client.end_transaction('ingest_submit', 'exception') def handle_complete(self): while self.running: result = self.complete_queue.pop(timeout=3) if not result: continue 
cpu_mark = time.process_time() time_mark = time.time() # Start of ingest message if self.apm_client: self.apm_client.begin_transaction('ingest_msg') sub = DatabaseSubmission(result) self.completed(sub) # End of ingest message (success) if self.apm_client: elasticapm.label(sid=sub.sid) self.apm_client.end_transaction('ingest_complete', 'success') self.counter.increment_execution_time( 'cpu_seconds', time.process_time() - cpu_mark) self.counter.increment_execution_time('busy_seconds', time.time() - time_mark) def handle_retries(self): tasks = [] while self.sleep(0 if tasks else 3): cpu_mark = time.process_time() time_mark = time.time() # Start of ingest message if self.apm_client: self.apm_client.begin_transaction('ingest_retries') tasks = self.retry_queue.dequeue_range(upper_limit=isotime.now(), num=100) for task in tasks: self.ingest_queue.push(task) # End of ingest message (success) if self.apm_client: elasticapm.label(retries=len(tasks)) self.apm_client.end_transaction('ingest_retries', 'success') self.counter.increment_execution_time( 'cpu_seconds', time.process_time() - cpu_mark) self.counter.increment_execution_time('busy_seconds', time.time() - time_mark) def handle_timeouts(self): timeouts = [] while self.sleep(0 if timeouts else 3): cpu_mark = time.process_time() time_mark = time.time() # Start of ingest message if self.apm_client: self.apm_client.begin_transaction('ingest_timeouts') timeouts = self.timeout_queue.dequeue_range( upper_limit=isotime.now(), num=100) for scan_key in timeouts: # noinspection PyBroadException try: actual_timeout = False # Remove the entry from the hash of submissions in progress. entry = self.scanning.pop(scan_key) if entry: actual_timeout = True self.log.error("Submission timed out for %s: %s", scan_key, str(entry)) dup = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False) if dup: actual_timeout = True while dup: self.log.error("Submission timed out for %s: %s", scan_key, str(dup)) dup = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False) if actual_timeout: self.counter.increment('timed_out') except Exception: self.log.exception("Problem timing out %s:", scan_key) # End of ingest message (success) if self.apm_client: elasticapm.label(timeouts=len(timeouts)) self.apm_client.end_transaction('ingest_timeouts', 'success') self.counter.increment_execution_time( 'cpu_seconds', time.process_time() - cpu_mark) self.counter.increment_execution_time('busy_seconds', time.time() - time_mark) def get_groups_from_user(self, username: str) -> List[str]: # Reset the group cache at the top of each hour if time.time() // HOUR_IN_SECONDS > self._user_groups_reset: self._user_groups = {} self._user_groups_reset = time.time() // HOUR_IN_SECONDS # Get the groups for this user if not known if username not in self._user_groups: user_data = self.datastore.user.get(username) if user_data: self._user_groups[username] = user_data.groups else: self._user_groups[username] = [] return self._user_groups[username] def ingest(self, task: IngestTask): self.log.info( f"[{task.ingest_id} :: {task.sha256}] Task received for processing" ) # Load a snapshot of ingest parameters as of right now. 
max_file_size = self.config.submission.max_file_size param = task.params self.counter.increment('bytes_ingested', increment_by=task.file_size) self.counter.increment('submissions_ingested') if any(len(file.sha256) != 64 for file in task.submission.files): self.log.error( f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped") self.send_notification(task, failure="Invalid sha256", logfunc=self.log.warning) return # Clean up metadata strings, since we may delete some, iterate on a copy of the keys for key in list(task.submission.metadata.keys()): value = task.submission.metadata[key] meta_size = len(value) if meta_size > self.config.submission.max_metadata_length: self.log.info( f'[{task.ingest_id} :: {task.sha256}] ' f'Removing {key} from metadata because value is too big') task.submission.metadata.pop(key) if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop: task.failure = f"File too large ({task.file_size} > {max_file_size})" self._notify_drop(task) self.counter.increment('skipped') self.log.error( f"[{task.ingest_id} :: {task.sha256}] {task.failure}") return # Set the groups from the user, if they aren't already set if not task.params.groups: task.params.groups = self.get_groups_from_user( task.params.submitter) # Check if this file is already being processed self.stamp_filescore_key(task) pprevious, previous, score = None, None, None if not param.ignore_cache: pprevious, previous, score, _ = self.check(task, count_miss=False) # Assign priority. low_priority = self.is_low_priority(task) priority = param.priority if priority < 0: priority = self.priority_value['medium'] if score is not None: priority = self.priority_value['low'] for level, threshold in self.threshold_value.items(): if score >= threshold: priority = self.priority_value[level] break elif low_priority: priority = self.priority_value['low'] # Reduce the priority by an order of magnitude for very old files. current_time = now() if priority and self.expired( current_time - task.submission.time.timestamp(), 0): priority = (priority / 10) or 1 param.priority = priority # Do this after priority has been assigned. # (So we don't end up dropping the resubmission). 
if previous: self.counter.increment('duplicates') self.finalize(pprevious, previous, score, task) # On cache hits of any kind we want to send out a completed message self.traffic_queue.publish( SubmissionMessage({ 'msg': task.submission, 'msg_type': 'SubmissionCompleted', 'sender': 'ingester', }).as_primitives()) return if self.drop(task): self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped") return if self.is_whitelisted(task): self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted") return self.unique_queue.push(priority, task.as_primitives()) def check( self, task: IngestTask, count_miss=True ) -> Tuple[Optional[str], Optional[str], Optional[float], str]: key = self.stamp_filescore_key(task) with self.cache_lock: result = self.cache.get(key, None) if result: self.counter.increment('cache_hit_local') self.log.info( f'[{task.ingest_id} :: {task.sha256}] Local cache hit') else: result = self.datastore.filescore.get_if_exists(key) if result: self.counter.increment('cache_hit') self.log.info( f'[{task.ingest_id} :: {task.sha256}] Remote cache hit') else: if count_miss: self.counter.increment('cache_miss') return None, None, None, key with self.cache_lock: self.cache[key] = result current_time = now() age = current_time - result.time errors = result.errors if self.expired(age, errors): self.log.info( f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired" ) self.counter.increment('cache_expired') self.cache.pop(key, None) self.datastore.filescore.delete(key) return None, None, None, key elif self.stale(age, errors): self.log.info( f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale" ) self.counter.increment('cache_stale') return None, None, result.score, key return result.psid, result.sid, result.score, key def stop(self): super().stop() if self.apm_client: elasticapm.uninstrument() self.submit_client.stop() def stale(self, delta: float, errors: int): if errors: return delta >= self.config.core.ingester.incomplete_stale_after_seconds else: return delta >= self.config.core.ingester.stale_after_seconds @staticmethod def stamp_filescore_key(task: IngestTask, sha256: str = None) -> str: if not sha256: sha256 = task.submission.files[0].sha256 key = task.submission.scan_key if not key: key = task.params.create_filescore_key(sha256) task.submission.scan_key = key return key def completed(self, sub: DatabaseSubmission): """Invoked when notified that a submission has completed.""" # There is only one file in the submissions we have made sha256 = sub.files[0].sha256 scan_key = sub.scan_key if not scan_key: self.log.warning( f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] " f"Submission missing scan key") scan_key = sub.params.create_filescore_key(sha256) raw = self.scanning.pop(scan_key) psid = sub.params.psid score = sub.max_score sid = sub.sid if not raw: # Some other worker has already popped the scanning queue? 
self.log.warning( f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] " f"Submission completed twice") return scan_key task = IngestTask(raw) task.submission.sid = sid errors = sub.error_count file_count = sub.file_count self.counter.increment('submissions_completed') self.counter.increment('files_completed', increment_by=file_count) self.counter.increment('bytes_completed', increment_by=task.file_size) with self.cache_lock: fs = self.cache[scan_key] = FileScore({ 'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60), 'errors': errors, 'psid': psid, 'score': score, 'sid': sid, 'time': now(), }) self.datastore.filescore.save(scan_key, fs) self.finalize(psid, sid, score, task) def exhaust() -> Iterable[IngestTask]: while True: res = self.duplicate_queue.pop(_dup_prefix + scan_key, blocking=False) if res is None: break res = IngestTask(res) res.submission.sid = sid yield res # You may be tempted to remove the assignment to dups and use the # value directly in the for loop below. That would be a mistake. # The function finalize may push on the duplicate queue which we # are pulling off and so condensing those two lines creates a # potential infinite loop. dups = [dup for dup in exhaust()] for dup in dups: self.finalize(psid, sid, score, dup) return scan_key def send_notification(self, task: IngestTask, failure=None, logfunc=None): if logfunc is None: logfunc = self.log.info if failure: task.failure = failure failure = task.failure if failure: logfunc("%s: %s", failure, str(task.json())) if not task.submission.notification.queue: return note_queue = _notification_queue_prefix + task.submission.notification.queue threshold = task.submission.notification.threshold if threshold is not None and task.score is not None and task.score < threshold: return q = self.notification_queues.get(note_queue, None) if not q: self.notification_queues[note_queue] = q = NamedQueue( note_queue, self.redis_persist) q.push(task.as_primitives()) def expired(self, delta: float, errors) -> bool: if errors: return delta >= self.config.core.ingester.incomplete_expire_after_seconds else: return delta >= self.config.core.ingester.expire_after def drop(self, task: IngestTask) -> bool: priority = task.params.priority sample_threshold = self.config.core.ingester.sampling_at dropped = False if priority <= _min_priority: dropped = True else: for level, rng in self.priority_range.items(): if rng[0] <= priority <= rng[1] and level in sample_threshold: dropped = must_drop(self.unique_queue.count(*rng), sample_threshold[level]) break if not dropped: if task.file_size > self.config.submission.max_file_size or task.file_size == 0: dropped = True if task.params.never_drop or not dropped: return False task.failure = 'Skipped' self._notify_drop(task) self.counter.increment('skipped') return True def _notify_drop(self, task: IngestTask): self.send_notification(task) c12n = task.params.classification expiry = now_as_iso(86400) sha256 = task.submission.files[0].sha256 self.datastore.save_or_freshen_file(sha256, {'sha256': sha256}, expiry, c12n, redis=self.redis) def is_whitelisted(self, task: IngestTask): reason, hit = self.get_whitelist_verdict(self.whitelist, task) hit = {x: dotdump(safe_str(y)) for x, y in hit.items()} sha256 = task.submission.files[0].sha256 if not reason: with self.whitelisted_lock: reason = self.whitelisted.get(sha256, None) if reason: hit = 'cached' if reason: if hit != 'cached': with self.whitelisted_lock: self.whitelisted[sha256] = reason task.failure = "Whitelisting due to reason %s (%s)" % 
(dotdump( safe_str(reason)), hit) self._notify_drop(task) self.counter.increment('whitelisted') return reason def submit(self, task: IngestTask): self.submit_client.submit( submission_obj=task.submission, completed_queue=COMPLETE_QUEUE_NAME, ) self.timeout_queue.push(int(now(_max_time)), task.submission.scan_key) self.log.info( f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis" ) def retry(self, task: IngestTask, scan_key: str, ex): current_time = now() retries = task.retries + 1 if retries > _max_retries: trace = '' if ex: trace = ': ' + get_stacktrace_info(ex) self.log.error( f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}' ) self.duplicate_queue.delete(_dup_prefix + scan_key) elif self.expired(current_time - task.ingest_time.timestamp(), 0): self.log.info( f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission' ) self.duplicate_queue.delete(_dup_prefix + scan_key) else: self.log.info( f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})' ) task.retries = retries self.retry_queue.push(int(now(_retry_delay)), task.as_primitives()) def finalize(self, psid: str, sid: str, score: float, task: IngestTask): self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed") if psid: task.params.psid = psid task.score = score task.submission.sid = sid selected = task.params.services.selected resubmit_to = task.params.services.resubmit resubmit_selected = determine_resubmit_selected(selected, resubmit_to) will_resubmit = resubmit_selected and should_resubmit(score) if will_resubmit: task.extended_scan = 'submitted' task.params.psid = None if self.is_alert(task, score): self.log.info( f"[{task.ingest_id} :: {task.sha256}] Notifying alerter " f"to {'update' if task.params.psid else 'create'} an alert") self.alert_queue.push(task.as_primitives()) self.send_notification(task) if will_resubmit: self.log.info( f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis" ) task.params.psid = sid task.submission.sid = None task.submission.scan_key = None task.params.services.resubmit = [] task.params.services.selected = resubmit_selected self.unique_queue.push(task.params.priority, task.as_primitives()) def is_alert(self, task: IngestTask, score: float) -> bool: if not task.params.generate_alert: return False if score < self.config.core.alerter.threshold: return False return True
def test_status_namespace(datastore, sio):
    status_queue = CommsQueue('status', private=True)
    monitoring = get_random_id()

    alerter_hb_msg = random_model_obj(AlerterMessage).as_primitives()
    dispatcher_hb_msg = random_model_obj(DispatcherMessage).as_primitives()
    expiry_hb_msg = random_model_obj(ExpiryMessage).as_primitives()
    ingest_hb_msg = random_model_obj(IngestMessage).as_primitives()
    service_hb_msg = random_model_obj(ServiceMessage).as_primitives()

    test_res_array = []

    @sio.on('monitoring', namespace='/status')
    def on_monitoring(data):
        # Confirmation that we are waiting for status messages
        test_res_array.append(('on_monitoring', data == monitoring))

    @sio.on('AlerterHeartbeat', namespace='/status')
    def on_alerter_heartbeat(data):
        test_res_array.append(('on_alerter_heartbeat', data == alerter_hb_msg['msg']))

    @sio.on('DispatcherHeartbeat', namespace='/status')
    def on_dispatcher_heartbeat(data):
        test_res_array.append(('on_dispatcher_heartbeat', data == dispatcher_hb_msg['msg']))

    @sio.on('ExpiryHeartbeat', namespace='/status')
    def on_expiry_heartbeat(data):
        test_res_array.append(('on_expiry_heartbeat', data == expiry_hb_msg['msg']))

    @sio.on('IngestHeartbeat', namespace='/status')
    def on_ingest_heartbeat(data):
        test_res_array.append(('on_ingest_heartbeat', data == ingest_hb_msg['msg']))

    @sio.on('ServiceHeartbeat', namespace='/status')
    def on_service_heartbeat(data):
        test_res_array.append(('on_service_heartbeat', data == service_hb_msg['msg']))

    try:
        sio.emit('monitor', monitoring, namespace='/status')
        sio.sleep(1)

        status_queue.publish(alerter_hb_msg)
        status_queue.publish(dispatcher_hb_msg)
        status_queue.publish(expiry_hb_msg)
        status_queue.publish(ingest_hb_msg)
        status_queue.publish(service_hb_msg)

        start_time = time.time()
        while len(test_res_array) < 6 and time.time() - start_time < 5:
            sio.sleep(0.1)

        assert len(test_res_array) == 6
        for test, result in test_res_array:
            if not result:
                pytest.fail(f"{test} failed.")
    finally:
        sio.disconnect()