def __init__(self, config=None):
        """Prepare the queues, scheduler and counter state of the manager.

        :param config: Optional configuration object. When it is falsy the
                       configuration comes from ``forge.get_config()``.
        """
        super().__init__('assemblyline.heartbeat_manager')
        self.config = config or forge.get_config()
        self.datastore = forge.get_datastore()
        self.metrics_queue = CommsQueue(METRICS_QUEUE)
        self.scheduler = BackgroundScheduler(daemon=True)
        self.hm = HeartbeatFormatter("heartbeat_manager", self.log, config=self.config)

        # Counter bookkeeping; counters_lock is the lock created for it.
        self.counters_lock = Lock()
        self.counters = {}
        self.rolling_window = {}
        self.window_ttl = {}

        metrics_config = self.config.core.metrics
        export_interval = metrics_config.export_interval
        self.ttl = export_interval * 2

        # The window size is the number of export intervals per minute; warn
        # when the export interval does not divide 60 evenly.
        intervals_per_minute = 60 / export_interval
        self.window_size = int(intervals_per_minute)
        if self.window_size != intervals_per_minute:
            self.log.warning("Cannot calculate a proper window size for reporting heartbeats. "
                             "Metrics reported during hearbeat will be wrong.")

        # The APM client is only created when a server URL is configured.
        if metrics_config.apm_server.server_url is None:
            self.apm_client = None
        else:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=metrics_config.apm_server.server_url,
                service_name="heartbeat_manager",
            )
    def __init__(self, sender, log, config=None, redis=None):
        """Connect to redis, build the queue handles and start the scheduler.

        :param sender: Stored as ``self.sender``.
        :param log: Logger stored as ``self.log``.
        :param config: Optional configuration; falls back to ``forge.get_config()``.
        :param redis: Optional non-persistent redis client; a new one is
                      created from the configuration when it is falsy.
        """
        self.sender = sender
        self.log = log

        self.config = config or forge.get_config()
        self.datastore = forge.get_datastore(self.config)

        # Redis connections: non-persistent (may be injected) and persistent.
        redis_config = self.config.core.redis
        self.redis = redis or get_client(
            host=redis_config.nonpersistent.host,
            port=redis_config.nonpersistent.port,
            private=False,
        )
        self.redis_persist = get_client(
            host=redis_config.persistent.host,
            port=redis_config.persistent.port,
            private=False,
        )

        # Handles on the non-persistent redis.
        self.status_queue = CommsQueue(STATUS_QUEUE, self.redis)
        self.dispatcher_submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.ingest_complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)

        # Handles on the persistent redis.
        self.dispatch_active_hash = Hash(DISPATCH_TASK_HASH, self.redis_persist)
        self.ingest_scanning = Hash('m-scanning-table', self.redis_persist)
        self.ingest_unique_queue = PriorityQueue('m-unique', self.redis_persist)
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Priority ranges and sampling values per priority level.
        levels = ('critical', 'high', 'medium', 'low')
        priority_ranges = forge.get_constants(self.config).PRIORITY_RANGES
        self.c_rng, self.h_rng, self.m_rng, self.l_rng = (
            priority_ranges[level] for level in levels)
        sampling_at = self.config.core.ingester.sampling_at
        self.c_s_at, self.h_s_at, self.m_s_at, self.l_s_at = (
            sampling_at[level] for level in levels)

        # Per-collection expiry counts, refreshed by _reload_expiry_queues.
        self.to_expire = dict.fromkeys(metrics.EXPIRY_METRICS, 0)

        # In batch-delete mode the upper bound of the query gets a "/DAY" suffix.
        expiry_config = self.config.core.expiry
        day_rounding = "/DAY" if expiry_config.batch_delete else ""
        store = self.datastore.ds
        self.delete_query = (
            f"expiry_ts:[* TO {store.now}-{expiry_config.delay}{store.hour}{day_rounding}]"
        )

        refresh_seconds = self.config.core.metrics.export_interval * 4
        self.scheduler = BackgroundScheduler(daemon=True)
        self.scheduler.add_job(self._reload_expiry_queues, 'interval', seconds=refresh_seconds)
        self.scheduler.start()
Ejemplo n.º 3
0
def test_submission_namespace(datastore, sio):
    """Publish four submission messages and expect five SocketIO events.

    The events are the 'monitoring' confirmation plus one per published
    message, all on the '/submissions' namespace.
    """
    submission_queue = CommsQueue('submissions', private=True)
    monitoring = get_random_id()

    def build_message(msg_type):
        # Random submission message whose type is forced to msg_type.
        message = random_model_obj(SubmissionMessage).as_primitives()
        message['msg_type'] = msg_type
        return message

    ingested = build_message("SubmissionIngested")
    received = build_message("SubmissionReceived")
    queued = build_message("SubmissionQueued")
    started = build_message("SubmissionStarted")

    # (handler label, payload matched) tuples appended by the handlers below.
    test_res_array = []

    def record(label, matched):
        test_res_array.append((label, matched))

    @sio.on('monitoring', namespace='/submissions')
    def on_monitoring(data):
        # Confirmation that we are waiting for status messages
        record('on_monitoring', data == monitoring)

    @sio.on('SubmissionIngested', namespace='/submissions')
    def on_submission_ingested(data):
        record('on_submission_ingested', data == ingested['msg'])

    @sio.on('SubmissionReceived', namespace='/submissions')
    def on_submission_received(data):
        record('on_submission_received', data == received['msg'])

    @sio.on('SubmissionQueued', namespace='/submissions')
    def on_submission_queued(data):
        record('on_submission_queued', data == queued['msg'])

    @sio.on('SubmissionStarted', namespace='/submissions')
    def on_submission_started(data):
        record('on_submission_started', data == started['msg'])

    try:
        sio.emit('monitor', monitoring, namespace='/submissions')
        sio.sleep(1)

        for outgoing in (ingested, received, queued, started):
            submission_queue.publish(outgoing)

        # Wait for all five events, but no longer than five seconds.
        began = time.time()
        while len(test_res_array) < 5 and time.time() - began < 5:
            sio.sleep(0.1)

        assert len(test_res_array) == 5

        first_failure = next((label for label, matched in test_res_array if not matched), None)
        if first_failure is not None:
            pytest.fail(f"{first_failure} failed.")
    finally:
        sio.disconnect()
Ejemplo n.º 4
0
def test_comms_queue(redis_connection):
    """Round-trip a list of messages through a CommsQueue.

    A background thread publishes the messages while this thread listens
    and compares each received message with the one sent at that position.
    Nothing is checked when ``redis_connection`` is falsy.
    """
    if not redis_connection:
        return

    from assemblyline.remote.datatypes.queues.comms import CommsQueue

    def publish_messages(message_list):
        # Short delay before publishing; the listener below is set up meanwhile.
        time.sleep(0.1)
        with CommsQueue('test-comms-queue') as cq_p:
            for message in message_list:
                cq_p.publish(message)

    msg_list = ["bob", 1, {"bob": 1}, [1, 2, 3], None, "Nice!", "stop"]
    publisher = Thread(target=publish_messages, args=(msg_list, ))
    publisher.start()

    with CommsQueue('test-comms-queue') as listener:
        for position, received in enumerate(listener.listen()):
            # "stop" is the sentinel that ends the listening loop.
            if received == "stop":
                break

            assert received == msg_list[position]

    publisher.join()
    assert not publisher.is_alive()
Ejemplo n.º 5
0
def test_alert_created(datastore, client):
    """Publish one created and one updated alert and check both callbacks.

    The messages are published from a background thread one second after
    the client starts listening with a two second timeout.
    """
    alert_queue = CommsQueue('alerts', private=True)

    def build_message(msg_type):
        # Random alert message whose type is forced to msg_type.
        message = random_model_obj(AlertMessage)
        message.msg_type = msg_type
        return message

    created = build_message("AlertCreated")
    updated = build_message("AlertUpdated")

    # (label, payload matched) tuples appended by the callbacks.
    outcomes = []

    def alerter_created_callback(data):
        outcomes.append(('created', created['msg'] == data))

    def alerter_updated_callback(data):
        outcomes.append(('updated', updated['msg'] == data))

    def publish_thread():
        time.sleep(1)
        for message in (created, updated):
            alert_queue.publish(message.as_primitives())

    threading.Thread(target=publish_thread).start()
    client.socketio.listen_on_alerts_messages(
        alert_created_callback=alerter_created_callback,
        alert_updated_callback=alerter_updated_callback,
        timeout=2)
    assert len(outcomes) == 2

    first_failure = next((label for label, matched in outcomes if not matched), None)
    if first_failure is not None:
        pytest.fail("{} failed.".format(first_failure))
Ejemplo n.º 6
0
def save_alert(datastore, counter, logger, alert, psid):
    """Save or update an alert, then publish a message on the 'alerts' queue.

    When ``psid`` is truthy an update is attempted first; if it raises
    AlertMissingError the alert is created instead.

    :return: 'update' when the alert was updated, 'create' when it was created.
    """
    def create_alert():
        # Save the alert as a new document and count it.
        datastore.alert.save(alert['alert_id'], alert)
        logger.info(f"Alert {alert['alert_id']} has been created.")
        counter.increment('created')
        return "AlertCreated", 'create'

    if not psid:
        msg_type, ret_val = create_alert()
    else:
        try:
            perform_alert_update(datastore, logger, alert)
            counter.increment('updated')
            msg_type, ret_val = "AlertUpdated", 'update'
        except AlertMissingError as e:
            logger.info(
                f"{str(e)}. Creating a new alert [{alert['alert_id']}]...")
            msg_type, ret_val = create_alert()

    notification = AlertMessage({
        "msg": alert,
        "msg_type": msg_type,
        "sender": "alerter"
    })
    CommsQueue('alerts').publish(notification.as_primitives())
    return ret_val
Ejemplo n.º 7
0
def test_alert_namespace(datastore, sio):
    """Check that published alerts are relayed on the '/alerts' namespace.

    Three events are expected: the 'monitoring' confirmation, then one
    'AlertCreated' and one 'AlertUpdated' event whose payloads match the
    published messages.
    """
    alert_queue = CommsQueue('alerts', private=True)
    test_id = get_random_id()

    created = random_model_obj(AlertMessage)
    created.msg_type = "AlertCreated"

    updated = random_model_obj(AlertMessage)
    updated.msg_type = "AlertUpdated"

    # (handler name, payload matched) tuples appended by the handlers below.
    test_res_array = []

    @sio.on('monitoring', namespace='/alerts')
    def on_monitoring(data):
        # Confirmation that we are waiting for alerts
        test_res_array.append(('on_monitoring', data == test_id))

    @sio.on('AlertCreated', namespace='/alerts')
    def on_alert_created(data):
        test_res_array.append(
            ('on_alert_created', data == created.as_primitives()['msg']))

    @sio.on('AlertUpdated', namespace='/alerts')
    def on_alert_updated(data):
        test_res_array.append(
            ('on_alert_updated', data == updated.as_primitives()['msg']))

    try:
        sio.emit('alert', test_id, namespace='/alerts')
        sio.sleep(1)

        alert_queue.publish(created.as_primitives())
        alert_queue.publish(updated.as_primitives())

        start_time = time.time()

        # Stop polling as soon as all three events arrived OR the five second
        # budget is spent. With `or` here the loop always ran for the full
        # five seconds and never ended when an event went missing.
        while len(test_res_array) < 3 and time.time() - start_time < 5:
            sio.sleep(0.1)

        assert len(test_res_array) == 3

        for test, result in test_res_array:
            if not result:
                pytest.fail(f"{test} failed.")

    finally:
        sio.disconnect()
    def try_run(self):
        """Listen on the metrics queue and accumulate the received counters.

        Sets up the Elasticsearch client (with a custom CA bundle when one is
        configured), schedules ``_create_aggregated_metrics`` every 60
        seconds, then folds each incoming metrics message into
        ``self.counters`` for as long as ``self.running`` is true.
        """
        # If our connection to the metrics database requires a custom ca cert, prepare it
        ca_certs = None
        if self.config.core.metrics.elasticsearch.host_certificates:
            # delete=False keeps the file on disk after the with-block so its
            # path can be handed to the Elasticsearch client below.
            with tempfile.NamedTemporaryFile(delete=False) as ca_certs_file:
                ca_certs = ca_certs_file.name
                ca_certs_file.write(self.config.core.metrics.elasticsearch.host_certificates.encode())

        self.metrics_queue = CommsQueue(METRICS_QUEUE)
        self.es = elasticsearch.Elasticsearch(hosts=self.elastic_hosts,
                                              connection_class=elasticsearch.RequestsHttpConnection,
                                              ca_certs=ca_certs)

        self.scheduler.add_job(self._create_aggregated_metrics, 'interval', seconds=60)
        self.scheduler.start()

        while self.running:
            for msg in self.metrics_queue.listen():
                # APM Transaction start
                if self.apm_client:
                    self.apm_client.begin_transaction('metrics')

                # Strip the identifying fields; what remains in msg are the counters.
                m_name = msg.pop('name', None)
                m_type = msg.pop('type', None)
                msg.pop('host', None)
                msg.pop('instance', None)

                # Validate before touching m_type: calling m_type.upper() on a
                # message without a type raises AttributeError.
                if not m_name or not m_type:
                    # APM Transaction end
                    if self.apm_client:
                        self.apm_client.end_transaction('process_message', 'invalid_message')

                    continue

                self.log.debug(f"Received {m_type.upper()} metrics message")

                with self.counters_lock:
                    c_key = (m_name, m_type)
                    # Types in NON_AGGREGATED are replaced rather than summed.
                    if c_key not in self.counters or m_type in NON_AGGREGATED:
                        self.counters[c_key] = Counter(msg)
                    else:
                        self.counters[c_key].update(Counter(msg))

                # APM Transaction end
                if self.apm_client:
                    self.apm_client.end_transaction('process_message', 'success')
Ejemplo n.º 9
0
def test_status_messages(datastore, client):
    """Publish six status messages and check that every callback gets its payload.

    The messages are published from a background thread one second after
    the client starts listening with a two second timeout.
    """
    status_queue = CommsQueue('status', private=True)
    # (label, payload matched) tuples appended by the callbacks.
    outcomes = []

    alerter_hb_msg = random_model_obj(AlerterMessage).as_primitives()
    dispatcher_hb_msg = random_model_obj(DispatcherMessage).as_primitives()
    expiry_hb_msg = random_model_obj(ExpiryMessage).as_primitives()
    ingest_hb_msg = random_model_obj(IngestMessage).as_primitives()
    service_hb_msg = random_model_obj(ServiceMessage).as_primitives()
    service_timing_msg = random_model_obj(ServiceTimingMessage).as_primitives()

    def make_callback(label, expected):
        # Callback recording whether the received data equals expected['msg'].
        def callback(data):
            outcomes.append((label, expected['msg'] == data))
        return callback

    def publish_thread():
        time.sleep(1)
        for message in (alerter_hb_msg, dispatcher_hb_msg, expiry_hb_msg,
                        ingest_hb_msg, service_hb_msg, service_timing_msg):
            status_queue.publish(message)

    threading.Thread(target=publish_thread).start()
    client.socketio.listen_on_status_messages(
        alerter_msg_callback=make_callback('alerter', alerter_hb_msg),
        dispatcher_msg_callback=make_callback('dispatcher', dispatcher_hb_msg),
        expiry_msg_callback=make_callback('expiry', expiry_hb_msg),
        ingest_msg_callback=make_callback('ingest', ingest_hb_msg),
        service_msg_callback=make_callback('service', service_hb_msg),
        service_timing_msg_callback=make_callback('service_timing', service_timing_msg),
        timeout=2)
    assert len(outcomes) == 6

    first_failure = next((label for label, matched in outcomes if not matched), None)
    if first_failure is not None:
        pytest.fail("{} failed.".format(first_failure))
Ejemplo n.º 10
0
    def monitor_system_status(self):
        """Relay every message from the 'status' queue as a SocketIO event.

        The loop ends when ``self.stop`` is truthy at the time a message
        arrives. On exit ``self.background_task`` is reset to None under
        ``self.connections_lock``.
        """
        status_queue = CommsQueue('status', private=True)
        try:
            for status in status_queue.listen():
                if self.stop:
                    break

                payload = status['msg']
                event_name = status['msg_type']
                self.socketio.emit(event_name, payload, namespace=self.namespace)
                LOGGER.info(
                    f"SocketIO:{self.namespace} - Sending {event_name} event to all connected users."
                )

        except Exception:
            LOGGER.exception(f"SocketIO:{self.namespace}")
        finally:
            LOGGER.info(
                f"SocketIO:{self.namespace} - No more users connected to status monitoring, exiting thread..."
            )
            with self.connections_lock:
                self.background_task = None
Ejemplo n.º 11
0
    def monitor_alerts(self, user_info):
        """Forward alert messages to one client for as long as it is connected.

        Only alerts whose classification is accessible to the user are
        emitted, to the room named after the user's session id.

        :param user_info: Mapping read for 'sid', 'classification', 'display'
                          and 'uname'.
        """
        sid = user_info['sid']
        alert_queue = CommsQueue('alerts', private=True)
        try:
            for received in alert_queue.listen():
                # Stop as soon as this session is no longer registered.
                if sid not in self.connections:
                    break

                alert = received['msg']
                msg_type = received['msg_type']

                alert_classification = alert.get('classification', classification.UNRESTRICTED)
                if not classification.is_accessible(user_info['classification'],
                                                    alert_classification):
                    continue

                self.socketio.emit(msg_type, alert, room=sid, namespace=self.namespace)
                LOGGER.info(
                    f"SocketIO:{self.namespace} - {user_info['display']} - "
                    f"Sending {msg_type} event for alert matching ID: {alert['alert_id']}"
                )

                if AUDIT:
                    AUDIT_LOG.info(
                        f"{user_info['uname']} [{user_info['classification']}]"
                        f" :: AlertMonitoringNamespace.get_alert(alert_id={alert['alert_id']})"
                    )

        except Exception:
            LOGGER.exception(
                f"SocketIO:{self.namespace} - {user_info['display']}")
        finally:
            LOGGER.info(
                f"SocketIO:{self.namespace} - {user_info['display']} - Connection to client was terminated"
            )
Ejemplo n.º 12
0
def test_submission_ingested(datastore, client):
    """Publish four submission messages and check each callback's payload.

    The messages are published from a background thread one second after
    the client starts listening with a two second timeout.
    """
    submission_queue = CommsQueue('submissions', private=True)
    # (label, payload matched) tuples appended by the callbacks.
    outcomes = []

    def build_message(msg_type):
        # Random submission message whose type is forced to msg_type.
        message = random_model_obj(SubmissionMessage).as_primitives()
        message['msg_type'] = msg_type
        return message

    completed = build_message("SubmissionCompleted")
    ingested = build_message("SubmissionIngested")
    received = build_message("SubmissionReceived")
    started = build_message("SubmissionStarted")

    def make_callback(label, expected):
        # Callback recording whether the received data equals expected['msg'].
        def callback(data):
            outcomes.append((label, expected['msg'] == data))
        return callback

    def publish_thread():
        time.sleep(1)
        for message in (completed, ingested, received, started):
            submission_queue.publish(message)

    threading.Thread(target=publish_thread).start()
    client.socketio.listen_on_submissions(
        completed_callback=make_callback('completed', completed),
        ingested_callback=make_callback('ingested', ingested),
        received_callback=make_callback('received', received),
        started_callback=make_callback('started', started),
        timeout=2)

    assert len(outcomes) == 4

    first_failure = next((label for label, matched in outcomes if not matched), None)
    if first_failure is not None:
        pytest.fail("{} failed.".format(first_failure))
Ejemplo n.º 13
0
def save_alert(datastore, counter, logger, alert, psid):
    """Update the alert when ``psid`` is truthy, otherwise save it as new.

    Either way a message describing the operation is published on the
    'alerts' queue.

    :return: 'update' when the alert was updated, 'create' when it was created.
    """
    if not psid:
        datastore.alert.save(alert['alert_id'], alert)
        logger.info(f"Alert {alert['alert_id']} has been created.")
        counter.increment('created')
        msg_type, ret_val = "AlertCreated", 'create'
    else:
        perform_alert_update(datastore, logger, alert)
        counter.increment('updated')
        msg_type, ret_val = "AlertUpdated", 'update'

    notification = AlertMessage({
        "msg": alert,
        "msg_type": msg_type,
        "sender": "alerter"
    })
    CommsQueue('alerts').publish(notification.as_primitives())
    return ret_val
Ejemplo n.º 14
0
def _test_message_through_queue(queue_name, test_message, redis):
    """Send ``test_message`` through a CommsQueue and rebuild it on receipt.

    A background thread runs ``publish_message``. The first message received
    is reloaded through the class named in its 'msg_loader' field and must
    equal ``test_message``.

    :raises ValueError: when the received message has no 'msg_loader' entry.
    """
    publisher = Thread(target=publish_message, args=(queue_name, test_message, redis))

    try:
        publisher.start()

        with CommsQueue(queue_name) as listener:
            for received in listener.listen():
                loader_path = received.get('msg_loader')
                if loader_path is None:
                    raise ValueError(
                        "Message does not have a message loader class path.")

                message_class = load_module_by_path(loader_path)
                assert message_class(received) == test_message

                # Only the first message is of interest.
                break

    finally:
        publisher.join()
        assert not publisher.is_alive()
Ejemplo n.º 15
0
 def publish_messages(message_list):
     """Publish every item of message_list on 'test-comms-queue' after a short pause."""
     time.sleep(0.1)
     with CommsQueue('test-comms-queue') as publisher:
         for item in message_list:
             publisher.publish(item)
class HeartbeatFormatter(object):
    """Build heartbeat messages for core components and publish them.

    Each heartbeat combines the metrics passed in by the caller with queue
    lengths and other state read from redis, and is published on the status
    queue. Expiry counts per collection are refreshed in the background.
    """

    def __init__(self, sender, log, config=None, redis=None):
        """Connect to redis, build the queue handles and start the scheduler.

        :param sender: Value placed in the "sender" field of every heartbeat.
        :param log: Logger used for all messages of this object.
        :param config: Optional configuration; falls back to ``forge.get_config()``.
        :param redis: Optional non-persistent redis client; a new one is
                      created from the configuration when it is falsy.
        """
        self.sender = sender
        self.log = log

        self.config = config or forge.get_config()
        self.datastore = forge.get_datastore(self.config)

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.redis_persist = get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )
        self.status_queue = CommsQueue(STATUS_QUEUE, self.redis)
        self.dispatch_active_hash = Hash(DISPATCH_TASK_HASH,
                                         self.redis_persist)
        self.dispatcher_submission_queue = NamedQueue(SUBMISSION_QUEUE,
                                                      self.redis)
        self.ingest_scanning = Hash('m-scanning-table', self.redis_persist)
        self.ingest_unique_queue = PriorityQueue('m-unique',
                                                 self.redis_persist)
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)
        self.ingest_complete_queue = NamedQueue(COMPLETE_QUEUE_NAME,
                                                self.redis)
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Priority ranges and sampling values, one per priority level.
        constants = forge.get_constants(self.config)
        self.c_rng = constants.PRIORITY_RANGES['critical']
        self.h_rng = constants.PRIORITY_RANGES['high']
        self.m_rng = constants.PRIORITY_RANGES['medium']
        self.l_rng = constants.PRIORITY_RANGES['low']
        self.c_s_at = self.config.core.ingester.sampling_at['critical']
        self.h_s_at = self.config.core.ingester.sampling_at['high']
        self.m_s_at = self.config.core.ingester.sampling_at['medium']
        self.l_s_at = self.config.core.ingester.sampling_at['low']

        # Per-collection expiry counts, refreshed by _reload_expiry_queues.
        self.to_expire = {k: 0 for k in metrics.EXPIRY_METRICS}
        # In batch-delete mode the upper bound of the query gets a "/DAY" suffix.
        if self.config.core.expiry.batch_delete:
            self.delete_query = f"expiry_ts:[* TO {self.datastore.ds.now}-{self.config.core.expiry.delay}" \
                f"{self.datastore.ds.hour}/DAY]"
        else:
            self.delete_query = f"expiry_ts:[* TO {self.datastore.ds.now}-{self.config.core.expiry.delay}" \
                f"{self.datastore.ds.hour}]"

        # Refresh the expiry counts every four export intervals.
        self.scheduler = BackgroundScheduler(daemon=True)
        self.scheduler.add_job(
            self._reload_expiry_queues,
            'interval',
            seconds=self.config.core.metrics.export_interval * 4)
        self.scheduler.start()

    def _reload_expiry_queues(self):
        """Refresh ``self.to_expire`` with the hit count of the delete query.

        A collection whose search raises SearchException is counted as 0.
        Any other exception is logged and not propagated.
        """
        try:
            self.log.info("Refreshing expiry queues...")
            for collection_name in metrics.EXPIRY_METRICS:
                try:
                    collection = getattr(self.datastore, collection_name)
                    # rows=0: only the total is read from the result.
                    self.to_expire[collection_name] = collection.search(
                        self.delete_query,
                        rows=0,
                        fl='id',
                        track_total_hits="true")['total']
                except SearchException:
                    self.to_expire[collection_name] = 0
        except Exception:
            self.log.exception(
                "Unknown exception occurred while reloading expiry queues:")

    def send_heartbeat(self, m_type, m_name, m_data, instances):
        """Build and publish the heartbeat matching ``m_type``.

        Failures while building or publishing a heartbeat are logged and not
        propagated. Unknown types are logged with a warning and skipped.

        :param m_type: Component type: "dispatcher", "ingester", "alerter",
                       "expiry", "archive", "scaler", "scaler_status" or
                       "service".
        :param m_name: Component or service name.
        :param m_data: Metrics placed in the "metrics" field of the message.
        :param instances: Instance count placed in the message. The
                          "dispatcher" and "service" branches ignore it and
                          compute their own values; "scaler_status" does not
                          report it.
        """
        if m_type == "dispatcher":
            try:
                # The caller-provided value is replaced by the dispatcher
                # instances listed in persistent redis.
                instances = sorted(Dispatcher.all_instances(
                    self.redis_persist))
                inflight = {
                    _i: Dispatcher.instance_assignment_size(
                        self.redis_persist, _i)
                    for _i in instances
                }
                queues = {
                    _i: Dispatcher.all_queue_lengths(self.redis, _i)
                    for _i in instances
                }

                msg = {
                    "sender": self.sender,
                    "msg": {
                        "inflight": {
                            "max": self.config.core.dispatcher.max_inflight,
                            "outstanding": self.dispatch_active_hash.length(),
                            "per_instance": [inflight[_i] for _i in instances]
                        },
                        "instances": len(instances),
                        "metrics": m_data,
                        "queues": {
                            "ingest":
                            self.dispatcher_submission_queue.length(),
                            "start": [queues[_i]['start'] for _i in instances],
                            "result":
                            [queues[_i]['result'] for _i in instances],
                            "command":
                            [queues[_i]['command'] for _i in instances]
                        },
                        "component": m_name,
                    }
                }
                self.status_queue.publish(
                    DispatcherMessage(msg).as_primitives())
                self.log.info(f"Sent dispatcher heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception(
                    "An exception occurred while generating DispatcherMessage")

        elif m_type == "ingester":
            try:
                # Length of the unique queue within each priority range.
                c_q_len = self.ingest_unique_queue.count(*self.c_rng)
                h_q_len = self.ingest_unique_queue.count(*self.h_rng)
                m_q_len = self.ingest_unique_queue.count(*self.m_rng)
                l_q_len = self.ingest_unique_queue.count(*self.l_rng)

                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                        "processing": {
                            "inflight": self.ingest_scanning.length()
                        },
                        "processing_chance": {
                            "critical": 1 - drop_chance(c_q_len, self.c_s_at),
                            "high": 1 - drop_chance(h_q_len, self.h_s_at),
                            "low": 1 - drop_chance(l_q_len, self.l_s_at),
                            "medium": 1 - drop_chance(m_q_len, self.m_s_at)
                        },
                        "queues": {
                            "critical": c_q_len,
                            "high": h_q_len,
                            "ingest": self.ingest_queue.length(),
                            "complete": self.ingest_complete_queue.length(),
                            "low": l_q_len,
                            "medium": m_q_len
                        }
                    }
                }
                self.status_queue.publish(IngestMessage(msg).as_primitives())
                self.log.info(f"Sent ingester heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception(
                    "An exception occurred while generating IngestMessage")

        elif m_type == "alerter":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                        "queues": {
                            "alert": self.alert_queue.length()
                        }
                    }
                }
                self.status_queue.publish(AlerterMessage(msg).as_primitives())
                self.log.info(f"Sent alerter heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception(
                    "An exception occurred while generating AlerterMessage")

        elif m_type == "expiry":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                        "queues": self.to_expire
                    }
                }
                self.status_queue.publish(ExpiryMessage(msg).as_primitives())
                self.log.info(f"Sent expiry heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception(
                    "An exception occurred while generating ExpiryMessage")

        elif m_type == "archive":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data
                    }
                }
                self.status_queue.publish(ArchiveMessage(msg).as_primitives())
                self.log.info(f"Sent archive heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception(
                    "An exception occurred while generating ArchiveMessage")

        elif m_type == "scaler":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": instances,
                        "metrics": m_data,
                    }
                }
                self.status_queue.publish(ScalerMessage(msg).as_primitives())
                self.log.info(f"Sent scaler heartbeat: {msg['msg']}")
            except Exception:
                # Name the message class built in this branch.
                self.log.exception(
                    "An exception occurred while generating ScalerMessage")

        elif m_type == "scaler_status":
            try:
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "service_name": m_name,
                        "metrics": m_data,
                    }
                }
                self.status_queue.publish(
                    ScalerStatusMessage(msg).as_primitives())
                self.log.info(f"Sent scaler status heartbeat: {msg['msg']}")
            except Exception:
                # Name the message class built in this branch.
                self.log.exception(
                    "An exception occurred while generating ScalerStatusMessage")

        elif m_type == "service":
            try:
                # The instance count is derived from the busy and idle
                # workers instead of the caller-provided value.
                busy, idle = get_working_and_idle(self.redis, m_name)
                msg = {
                    "sender": self.sender,
                    "msg": {
                        "instances": len(busy) + len(idle),
                        "metrics": m_data,
                        "activity": {
                            'busy': len(busy),
                            'idle': len(idle)
                        },
                        "queue": get_service_queue(m_name,
                                                   self.redis).length(),
                        "service_name": m_name
                    }
                }
                self.status_queue.publish(ServiceMessage(msg).as_primitives())
                self.log.info(f"Sent service heartbeat: {msg['msg']}")
            except Exception:
                self.log.exception(
                    "An exception occurred while generating ServiceMessage")

        else:
            self.log.warning(
                f"Skipping unknown counter: {m_name} [{m_type}] ==> {m_data}")
Ejemplo n.º 17
0
    def __init__(self,
                 datastore,
                 logger,
                 classification=None,
                 redis=None,
                 persistent_redis=None,
                 metrics_name='ingester'):
        """Wire up the caches, configuration, redis connections and queues used for ingestion.

        Args:
            datastore: Datastore handle, kept on ``self.datastore`` and handed to
                the ``SubmissionClient``.
            logger: Logger, kept on ``self.log``.
            classification: Classification engine; ``forge.get_classification()``
                is used when omitted.
            redis: Client for the non-persistent redis; a new one is created from
                ``core.redis.nonpersistent`` when omitted.
            persistent_redis: Client for the persistent redis; a new one is created
                from ``core.redis.persistent`` when omitted.
            metrics_name: Passed as ``name`` to the ``MetricsFactory``.
        """
        self.datastore = datastore
        self.log = logger

        # Cache the user groups
        self.cache_lock = threading.RLock(
        )  # TODO are middle man instances single threaded now?
        self._user_groups = {}
        # Hour index (epoch seconds // HOUR_IN_SECONDS) at construction time;
        # presumably compared later to decide when to reset the group cache - verify in callers
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Create a config cache that will refresh config values periodically
        self.config = forge.CachedObject(forge.get_config)

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so has to be done at init rather than load
        constants = forge.get_constants(self.config)
        self.priority_value = constants.PRIORITIES
        self.priority_range = constants.PRIORITY_RANGES
        self.threshold_value = constants.PRIORITY_THRESHOLDS

        # Connect to the redis servers (caller-supplied clients take precedence)
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.persistent_redis = persistent_redis or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester',
                                      schema=Metrics,
                                      redis=self.redis,
                                      config=self.config,
                                      name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.persistent_redis)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(_completeq_name, self.redis)

        # Internal. Dropped entries are placed on this queue.
        # (Currently disabled; kept for reference.)
        # self.drop_queue = NamedQueue('m-drop', self.persistent_redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME,
                                       self.persistent_redis)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.persistent_redis)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.persistent_redis)

        # Internal, timeout watch queue (lives on the non-persistent redis)
        self.timeout_queue = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the complete
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestions of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.persistent_redis)

        # Output. Submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore,
                                              redis=self.redis)
Ejemplo n.º 18
0
def get_metrics_sink(redis=None):
    """Return the ``CommsQueue`` named ``'assemblyline_metrics'``.

    Args:
        redis: Forwarded unchanged as the queue's ``host`` argument.
    """
    # Imported lazily, exactly as before, so importing this module stays cheap
    from assemblyline.remote.datatypes.queues.comms import CommsQueue

    sink = CommsQueue('assemblyline_metrics', host=redis)
    return sink
Ejemplo n.º 19
0
#!/usr/bin/env python

import sys

from assemblyline.remote.datatypes.queues.comms import CommsQueue
from pprint import pprint

if __name__ == "__main__":
    # Usage: pubsub_reader.py [queue_name]
    # Pretty-prints every message published on the given CommsQueue until interrupted.
    queue_name = sys.argv[1] if len(sys.argv) > 1 else None

    if queue_name is None:
        print(
            "\nERROR: You must specify a queue name.\n\npubsub_reader.py [queue_name]"
        )
        # sys.exit instead of the bare exit(): the latter is injected by the
        # `site` module for interactive use and is not guaranteed to exist.
        sys.exit(1)

    print(f"Listening for messages on '{queue_name}' queue.")

    q = CommsQueue(queue_name)

    try:
        # Re-enter listen() if the generator ever finishes, so we keep reading forever
        while True:
            for msg in q.listen():
                pprint(msg)
    except KeyboardInterrupt:
        print('Exiting')
    finally:
        # Always release the queue's connection, whatever ended the loop
        q.close()
Ejemplo n.º 20
0
def publish_message(queue_name, test_message, redis):
    """Publish ``test_message`` (as primitives) on the named ``CommsQueue`` after a short pause.

    Args:
        queue_name: Name of the queue to publish on.
        test_message: Object exposing ``as_primitives()``; its result is the payload.
        redis: Passed as the second positional argument to ``CommsQueue``.
    """
    # Brief delay before publishing; presumably gives a listener time to subscribe - confirm with callers
    time.sleep(0.1)
    with CommsQueue(queue_name, redis) as channel:
        payload = test_message.as_primitives()
        channel.publish(payload)
class HeartbeatManager(ServerBase):
    """Turn raw metrics messages into periodic, aggregated heartbeats.

    Messages read from the metrics queue are accumulated per
    ``(name, type, host)`` in ``self.counters``. A background job runs every
    ``core.metrics.export_interval`` seconds, moves those counters into a
    per-component rolling window, sums the window per ``(name, type)`` and
    hands the result to ``HeartbeatFormatter.send_heartbeat``.
    """

    def __init__(self, config=None):
        """Set up the queue, scheduler, rolling-window state and optional APM client.

        Args:
            config: Configuration object; ``forge.get_config()`` is used when omitted.
        """
        super().__init__('assemblyline.heartbeat_manager')
        self.config = config or forge.get_config()
        self.datastore = forge.get_datastore()
        self.metrics_queue = CommsQueue(METRICS_QUEUE)
        self.scheduler = BackgroundScheduler(daemon=True)
        self.hm = HeartbeatFormatter("heartbeat_manager",
                                     self.log,
                                     config=self.config)

        # self.counters is written by the listener loop and swapped out by the
        # scheduler job, hence the lock.
        self.counters_lock = Lock()
        self.counters = {}
        self.rolling_window = {}
        self.window_ttl = {}
        self.ttl = self.config.core.metrics.export_interval * 2

        # Number of export intervals that fit in one minute
        exports_per_minute = 60 / self.config.core.metrics.export_interval
        self.window_size = int(exports_per_minute)
        if self.window_size != exports_per_minute:
            self.log.warning(
                "Cannot calculate a proper window size for reporting heartbeats. "
                "Metrics reported during hearbeat will be wrong.")
        # Never let the window be 0: ``some_list[-0:]`` is the *whole* list, so the
        # rolling window would never be trimmed and would grow without bound.
        self.window_size = max(self.window_size, 1)

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="heartbeat_manager")
        else:
            self.apm_client = None

    def try_run(self):
        """Start the export job, then accumulate incoming metrics until stopped.

        Each message must carry non-empty ``name``, ``type`` and ``host`` entries;
        messages that do not are ignored. The remaining entries of the message
        are treated as counter values.
        """
        self.scheduler.add_job(
            self._export_hearbeats,
            'interval',
            seconds=self.config.core.metrics.export_interval)
        self.scheduler.start()

        while self.running:
            for msg in self.metrics_queue.listen():
                # APM Transaction start
                if self.apm_client:
                    self.apm_client.begin_transaction('heartbeat')

                m_name = msg.pop('name', None)
                m_type = msg.pop('type', None)
                m_host = msg.pop('host', None)
                msg.pop('instance', None)

                # Validate before touching m_type: it is None when the message
                # has no 'type', and calling .upper() on it would crash the loop.
                if not m_name or not m_type or not m_host:
                    # APM Transaction end
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'process_message', 'invalid_message')

                    continue

                self.log.debug(
                    f"Received {str(m_type).upper()} metrics message")

                with self.counters_lock:
                    c_key = (m_name, m_type, m_host)
                    if c_key not in self.counters or m_type in NON_AGGREGATED:
                        # First message for this key, or a type that is replaced rather than summed
                        self.counters[c_key] = Counter(msg)
                    else:
                        # Remember the fields that must not be summed so they
                        # can be restored to the latest value after update()
                        non_agg_values = {}
                        if m_type in NON_AGGREGATED_COUNTERS:
                            non_agg_values = {
                                k: v
                                for k, v in msg.items()
                                if k in NON_AGGREGATED_COUNTERS[m_type]
                            }
                        self.counters[c_key].update(Counter(msg))
                        for k, v in non_agg_values.items():
                            self.counters[c_key][k] = v

                # APM Transaction end
                if self.apm_client:
                    self.apm_client.end_transaction('process_message',
                                                    'success')

    def _export_hearbeats(self):
        """Scheduler job: roll the current counters into the window and send heartbeats.

        Any exception is logged and swallowed so one bad run does not prevent
        the next scheduled run.
        """
        try:
            self.heartbeat()
            self.log.info("Expiring unused counters...")
            # APM Transaction start
            if self.apm_client:
                self.apm_client.begin_transaction('heartbeat')

            c_time = time.time()
            for k in list(self.window_ttl):
                if self.window_ttl.get(k, c_time) < c_time:
                    c_name, c_type, c_host = k
                    self.log.info(
                        f"Counter {c_name} [{c_type}] for host {c_host} is expired"
                    )
                    del self.window_ttl[k]
                    self.rolling_window.pop(k, None)

            self.log.info("Saving current counters to rolling window ...")
            # Swap under the lock so the listener starts from a clean slate
            with self.counters_lock:
                counter_copy, self.counters = self.counters, {}

            for w_key, counter in counter_copy.items():
                _, m_type, _ = w_key
                if w_key not in self.rolling_window or m_type in NON_AGGREGATED:
                    self.rolling_window[w_key] = [counter]
                else:
                    self.rolling_window[w_key].append(counter)

                # Keep only the most recent window_size exports
                self.rolling_window[w_key] = self.rolling_window[w_key][
                    -self.window_size:]
                self.window_ttl[w_key] = time.time() + self.ttl

            self.log.info("Compiling service list...")
            # Seed every enabled service with zeroed counters so a heartbeat is
            # sent for it even when no metrics were received.
            aggregated_counters = {}
            for service in [
                    s['name']
                    for s in self.datastore.list_all_services(as_obj=False)
                    if s['enabled']
            ]:
                data = {
                    'cache_hit': 0,
                    'cache_miss': 0,
                    'cache_skipped': 0,
                    'execute': 0,
                    'fail_recoverable': 0,
                    'fail_nonrecoverable': 0,
                    'scored': 0,
                    'not_scored': 0,
                    'instances': 0
                }
                aggregated_counters[(service, 'service')] = Counter(data)

            self.log.info("Aggregating heartbeat data...")
            for component_parts, counters_list in self.rolling_window.items():
                c_name, c_type, c_host = component_parts

                # Expiring data outside of the window
                counters_list = counters_list[-self.window_size:]

                key = (c_name, c_type)
                if key not in aggregated_counters:
                    aggregated_counters[key] = Counter()

                # One rolling-window entry per host => one instance
                aggregated_counters[key]['instances'] += 1

                for c in counters_list:
                    aggregated_counters[key].update(c)

            self.log.info("Generating heartbeats...")
            for aggregated_parts, counter in aggregated_counters.items():
                agg_c_name, agg_c_type = aggregated_parts
                with elasticapm.capture_span(name=f"{agg_c_type}.{agg_c_name}",
                                             span_type="send_heartbeat"):

                    metrics_data = {}
                    for key, value in counter.items():
                        # Skip counts, they will be paired with a time entry and we only want to count it once
                        if key.endswith('.c'):
                            continue
                        # We have an entry that is a timer, should also have a .c count
                        elif key.endswith('.t'):
                            # Remove exactly the '.t' suffix. str.rstrip('.t') strips a
                            # *set* of characters and would mangle names such as 'wait.t'.
                            name = key[:-2]
                            metrics_data[name] = value / max(
                                counter.get(name + ".c", 1), 1)
                            metrics_data[name + "_count"] = counter.get(
                                name + ".c", 0)
                        # Plain old metric, no modifications needed
                        else:
                            metrics_data[key] = value

                    agg_c_instances = metrics_data.pop('instances', 1)
                    metrics_data.pop('instances_count', None)
                    self.hm.send_heartbeat(agg_c_type, agg_c_name,
                                           metrics_data, agg_c_instances)

            # APM Transaction end
            if self.apm_client:
                self.apm_client.end_transaction('send_heartbeats', 'success')

        except Exception:
            self.log.exception(
                "Unknown exception occurred during heartbeat creation:")
class MetricsServer(ServerBase):
    """Aggregate metrics messages and index one document per component every minute.

    Only one instance of this server should run at a time because it reads
    from a pub/sub queue.
    """
    def __init__(self, config=None):
        """Read the configuration and prepare (but do not open) the connections.

        Exits the process with status 1 when no elasticsearch hosts are
        configured for metrics.

        Args:
            config: Configuration object; ``forge.get_config()`` is used when omitted.
        """
        super().__init__('assemblyline.metrics_aggregator',
                         shutdown_timeout=65)
        self.config = config or forge.get_config()
        self.elastic_hosts = self.config.core.metrics.elasticsearch.hosts
        self.is_datastream = False

        if not self.elastic_hosts:
            self.log.error(
                "No elasticsearch cluster defined to store metrics. All gathered stats will be ignored..."
            )
            sys.exit(1)

        self.scheduler = BackgroundScheduler(daemon=True)
        # Both connections are created in try_run()
        self.metrics_queue = None
        self.es = None
        # self.counters is written by the listener loop and swapped out by the
        # scheduler job, hence the lock.
        self.counters_lock = Lock()
        self.counters = {}

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="metrics_aggregator")
        else:
            self.apm_client = None

    def try_run(self):
        """Connect to the queue and elasticsearch, then accumulate metrics until stopped.

        Each message must carry non-empty ``name`` and ``type`` entries; messages
        that do not are ignored. ``host`` and ``instance`` are discarded and the
        remaining entries are treated as counter values.
        """
        # If our connection to the metrics database requires a custom ca cert, prepare it
        ca_certs = None
        if self.config.core.metrics.elasticsearch.host_certificates:
            # delete=False: the file has to outlive this ``with`` block because
            # only its path is handed to the elasticsearch client below.
            with tempfile.NamedTemporaryFile(delete=False) as ca_certs_file:
                ca_certs = ca_certs_file.name
                ca_certs_file.write(self.config.core.metrics.elasticsearch.
                                    host_certificates.encode())

        self.metrics_queue = CommsQueue(METRICS_QUEUE)
        self.es = elasticsearch.Elasticsearch(
            hosts=self.elastic_hosts,
            connection_class=elasticsearch.RequestsHttpConnection,
            ca_certs=ca_certs)
        # Determine if ES will support data streams (>= 7.9)
        self.is_datastream = version.parse(
            self.es.info()['version']['number']) >= version.parse("7.9")

        self.scheduler.add_job(self._create_aggregated_metrics,
                               'interval',
                               seconds=60)
        self.scheduler.start()

        while self.running:
            for msg in self.metrics_queue.listen():
                # APM Transaction start
                if self.apm_client:
                    self.apm_client.begin_transaction('metrics')

                m_name = msg.pop('name', None)
                m_type = msg.pop('type', None)
                msg.pop('host', None)
                msg.pop('instance', None)

                # Validate before touching m_type: it is None when the message
                # has no 'type', and calling .upper() on it would crash the loop.
                if not m_name or not m_type:
                    # APM Transaction end
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'process_message', 'invalid_message')

                    continue

                self.log.debug(
                    f"Received {str(m_type).upper()} metrics message")

                with self.counters_lock:
                    c_key = (m_name, m_type)
                    if c_key not in self.counters or m_type in NON_AGGREGATED:
                        # First message for this key, or a type that is replaced rather than summed
                        self.counters[c_key] = Counter(msg)
                    else:
                        # Remember the fields that must not be summed so they
                        # can be restored to the latest value after update()
                        non_agg_values = {}
                        if m_type in NON_AGGREGATED_COUNTERS:
                            non_agg_values = {
                                k: v
                                for k, v in msg.items()
                                if k in NON_AGGREGATED_COUNTERS[m_type]
                            }
                        self.counters[c_key].update(Counter(msg))
                        for k, v in non_agg_values.items():
                            self.counters[c_key][k] = v

                # APM Transaction end
                if self.apm_client:
                    self.apm_client.end_transaction('process_message',
                                                    'success')

    def _create_aggregated_metrics(self):
        """Scheduler job: flush the accumulated counters to elasticsearch.

        One document is indexed per ``(name, type)`` component into
        ``al_metrics_<type>`` (or ``al_metrics_<type>_ds`` when an index
        template of that name exists).
        """
        self.log.info("Copying counters ...")
        # APM Transaction start
        if self.apm_client:
            self.apm_client.begin_transaction('metrics')

        # Swap under the lock so the listener starts from a clean slate
        with self.counters_lock:
            counter_copy, self.counters = self.counters, {}

        self.log.info("Aggregating metrics ...")
        timestamp = now_as_iso()
        for component, counts in counter_copy.items():
            component_name, component_type = component
            output_metrics = {'name': component_name, 'type': component_type}

            for key, value in counts.items():
                # Skip counts, they will be paired with a time entry and we only want to count it once
                if key.endswith('.c'):
                    continue
                # We have an entry that is a timer, should also have a .c count
                elif key.endswith('.t'):
                    # Remove exactly the '.t' suffix. str.rstrip('.t') strips a
                    # *set* of characters and would mangle names such as 'wait.t'.
                    name = key[:-2]
                    # Clamp the divisor: a reported count of 0 must not raise ZeroDivisionError
                    output_metrics[name] = value / max(
                        counts.get(name + ".c", 1), 1)
                    output_metrics[name + "_count"] = counts.get(
                        name + ".c", 0)
                # Plain old metric, no modifications needed
                else:
                    output_metrics[key] = value

            ensure_indexes(self.log,
                           self.es,
                           self.config.core.metrics.elasticsearch,
                           [component_type],
                           datastream_enabled=self.is_datastream)

            index = f"al_metrics_{component_type}"
            # Were data streams created for the index specified?
            try:
                if self.es.indices.get_index_template(name=f"{index}_ds"):
                    output_metrics['@timestamp'] = timestamp
                    index = f"{index}_ds"
            except elasticsearch.exceptions.TransportError:
                # No usable data stream template: keep writing to the plain index
                pass
            output_metrics['timestamp'] = timestamp
            output_metrics = cleanup_metrics(output_metrics)

            self.log.info(output_metrics)
            with_retries(self.log,
                         self.es.index,
                         index=index,
                         body=output_metrics)

        self.log.info("Metrics aggregated. Waiting for next run...")

        # APM Transaction end
        if self.apm_client:
            self.apm_client.end_transaction('aggregate_metrics', 'success')
Ejemplo n.º 23
0
def get_submission_traffic_channel():
    """Return the ``'submissions'`` ``CommsQueue`` on the non-persistent redis from ``config``."""
    nonpersistent = config.core.redis.nonpersistent
    return CommsQueue('submissions',
                      host=nonpersistent.host,
                      port=nonpersistent.port)
Ejemplo n.º 24
0
class Ingester(ThreadedCoreBase):
    def __init__(self,
                 datastore=None,
                 logger=None,
                 classification=None,
                 redis=None,
                 persistent_redis=None,
                 metrics_name='ingester',
                 config=None):
        """Wire up the caches, queues, metrics and optional APM client used for ingestion.

        ``datastore``, ``logger``, ``redis``, ``persistent_redis`` and ``config``
        are forwarded to the base-class constructor; the attributes read below
        (``self.config``, ``self.redis``, ``self.redis_persist``,
        ``self.datastore``, ``self.log``) are presumably set there - confirm
        against ``ThreadedCoreBase``.

        Args:
            datastore: Forwarded to the base class as ``datastore``.
            logger: Forwarded to the base class as its second positional argument.
            classification: Classification engine; ``forge.get_classification()``
                is used when omitted.
            redis: Forwarded to the base class as ``redis``.
            persistent_redis: Forwarded to the base class as ``redis_persist``.
            metrics_name: Passed as ``name`` to the ``MetricsFactory``.
            config: Forwarded to the base class as ``config``.
        """
        super().__init__('assemblyline.ingester',
                         logger,
                         redis=redis,
                         redis_persist=persistent_redis,
                         datastore=datastore,
                         config=config)

        # Cache the user groups
        self.cache_lock = threading.RLock()
        self._user_groups = {}
        # Hour index (epoch seconds // HOUR_IN_SECONDS) at construction time;
        # presumably compared later to decide when to reset the group cache - verify in callers
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so has to be done at init rather than load
        constants = forge.get_constants(self.config)
        self.priority_value: dict[str, int] = constants.PRIORITIES
        self.priority_range: dict[str, Tuple[int,
                                             int]] = constants.PRIORITY_RANGES
        self.threshold_value: dict[str, int] = constants.PRIORITY_THRESHOLDS

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester',
                                      schema=Metrics,
                                      redis=self.redis,
                                      config=self.config,
                                      name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.redis_persist)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.redis_persist)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.redis_persist)

        # Internal, timeout watch queue (lives on the non-persistent redis)
        self.timeout_queue: PriorityQueue[str] = PriorityQueue(
            'm-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the complete
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestions of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.redis_persist)

        # Output. Submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore,
                                              redis=self.redis)

        # APM reporting is optional: enabled only when a server URL is configured
        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="ingester")
        else:
            self.apm_client = None

    def try_run(self):
        """Build the name -> target map of worker threads and hand it to ``maintain_threads``.

        There is one 'Retries' and one 'Timeouts' worker, plus numbered
        'Complete_n', 'Ingest_n' and 'Submit_n' workers sized by
        ``COMPLETE_THREADS``, ``INGEST_THREADS`` and ``SUBMIT_THREADS``.
        """
        workers = {
            'Retries': self.handle_retries,
            'Timeouts': self.handle_timeouts,
        }

        # (name prefix, thread target, number of threads), in registration order
        pools = (
            ('Complete', self.handle_complete, COMPLETE_THREADS),
            ('Ingest', self.handle_ingest, INGEST_THREADS),
            ('Submit', self.handle_submit, SUBMIT_THREADS),
        )
        for prefix, target, count in pools:
            for n in range(count):
                workers[f'{prefix}_{n}'] = target

        self.maintain_threads(workers)

    def handle_ingest(self):
        """Consume the ingest queue, converting each message to an ``IngestTask`` and ingesting it.

        Messages containing a ``'submission'`` key are treated as already-built
        (retried) tasks. Anything else is parsed as a new submission, which is
        also published on the traffic queue. Messages that fail to parse with
        ``ValueError`` or ``TypeError`` are counted as errors, logged and dropped.
        """
        cpu_start = time.process_time()
        wall_start = time.time()

        while self.running:
            # Report the CPU / wall time consumed since the marks were last reset
            cpu_spent = time.process_time() - cpu_start
            self.counter.increment_execution_time('cpu_seconds', cpu_spent)
            wall_spent = time.time() - wall_start
            self.counter.increment_execution_time('busy_seconds', wall_spent)

            message = self.ingest_queue.pop(timeout=1)

            # Marks are reset right after the pop, so time spent inside pop() is not reported
            cpu_start = time.process_time()
            wall_start = time.time()

            if not message:
                continue

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            try:
                if 'submission' not in message:
                    # A new submission
                    submission = MessageSubmission(message)
                    task = IngestTask({
                        'submission': submission,
                        'ingest_id': submission.sid,
                    })
                    task.submission.sid = None  # Reset to new random uuid
                    # Write all input to the traffic queue
                    traffic_record = SubmissionMessage({
                        'msg': submission,
                        'msg_type': 'SubmissionIngested',
                        'sender': 'ingester',
                    })
                    self.traffic_queue.publish(traffic_record.as_primitives())
                else:
                    # A retried task
                    task = IngestTask(message)
            except (ValueError, TypeError) as error:
                self.counter.increment('error')
                self.log.exception(
                    f"Dropped ingest submission {message} because {str(error)}"
                )

                # End of ingest message (value_error)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_input',
                                                    'value_error')
            else:
                # Parsing succeeded; errors raised by ingest() itself are not handled here
                self.ingest(task)

                # End of ingest message (success)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_input', 'success')

    def handle_submit(self):
        time_mark, cpu_mark = time.time(), time.process_time()

        while self.running:
            # noinspection PyBroadException
            try:
                self.counter.increment_execution_time(
                    'cpu_seconds',
                    time.process_time() - cpu_mark)
                self.counter.increment_execution_time('busy_seconds',
                                                      time.time() - time_mark)

                # Check if there is room for more submissions
                length = self.scanning.length()
                if length >= self.config.core.ingester.max_inflight:
                    self.sleep(0.1)
                    time_mark, cpu_mark = time.time(), time.process_time()
                    continue

                raw = self.unique_queue.blocking_pop(timeout=3)
                time_mark, cpu_mark = time.time(), time.process_time()
                if not raw:
                    continue

                # Start of ingest message
                if self.apm_client:
                    self.apm_client.begin_transaction('ingest_msg')

                task = IngestTask(raw)

                # Check if we need to drop a file for capacity reasons, but only if the
                # number of files in flight is alreay over 80%
                if length >= self.config.core.ingester.max_inflight * 0.8 and self.drop(
                        task):
                    # End of ingest message (dropped)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'dropped')
                    continue

                if self.is_whitelisted(task):
                    # End of ingest message (whitelisted)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'whitelisted')
                    continue

                # Check if this file has been previously processed.
                pprevious, previous, score, scan_key = None, None, None, None
                if not task.submission.params.ignore_cache:
                    pprevious, previous, score, scan_key = self.check(task)
                else:
                    scan_key = self.stamp_filescore_key(task)

                # If it HAS been previously processed, we are dealing with a resubmission
                # finalize will decide what to do, and put the task back in the queue
                # rewritten properly if we are going to run it again
                if previous:
                    if not task.submission.params.services.resubmit and not pprevious:
                        self.log.warning(
                            f"No psid for what looks like a resubmission of "
                            f"{task.submission.files[0].sha256}: {scan_key}")
                    self.finalize(pprevious, previous, score, task)
                    # End of ingest message (finalized)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'finalized')

                    continue

                # We have decided this file is worth processing

                # Add the task to the scanning table, this is atomic across all submit
                # workers, so if it fails, someone beat us to the punch, record the file
                # as a duplicate then.
                if not self.scanning.add(scan_key, task.as_primitives()):
                    self.log.debug('Duplicate %s',
                                   task.submission.files[0].sha256)
                    self.counter.increment('duplicates')
                    self.duplicate_queue.push(_dup_prefix + scan_key,
                                              task.as_primitives())
                    # End of ingest message (duplicate)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'duplicate')

                    continue

                # We have managed to add the task to the scan table, so now we go
                # ahead with the submission process
                try:
                    self.submit(task)
                    # End of ingest message (submitted)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'submitted')

                    continue
                except Exception as _ex:
                    # For some reason (contained in `ex`) we have failed the submission
                    # The rest of this function is error handling/recovery
                    ex = _ex
                    # traceback = _ex.__traceback__

                self.counter.increment('error')

                should_retry = True
                if isinstance(ex, CorruptedFileStoreException):
                    self.log.error(
                        "Submission for file '%s' failed due to corrupted "
                        "filestore: %s" % (task.sha256, str(ex)))
                    should_retry = False
                elif isinstance(ex, DataStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed due to "
                                   "data store error:\n%s" %
                                   (task.sha256, trace))
                elif not isinstance(ex, FileStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed: %s" %
                                   (task.sha256, trace))

                task = IngestTask(self.scanning.pop(scan_key))
                if not task:
                    self.log.error('No scanning entry for for %s', task.sha256)
                    # End of ingest message (no_scan_entry)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'no_scan_entry')

                    continue

                if not should_retry:
                    # End of ingest message (cannot_retry)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'cannot_retry')

                    continue

                self.retry(task, scan_key, ex)
                # End of ingest message (retry)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit', 'retried')

            except Exception:
                self.log.exception("Unexpected error")
                # End of ingest message (exception)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit',
                                                    'exception')

    def handle_complete(self):
        """Process submission-completed messages for as long as the service runs.

        Every message popped from ``complete_queue`` is converted into a
        ``DatabaseSubmission`` and handed to ``completed``. The CPU time and
        wall-clock time spent on each message are added to the execution-time
        counters, and the work is wrapped in an APM transaction when an APM
        client is configured.
        """
        while self.running:
            message = self.complete_queue.pop(timeout=3)
            if not message:
                # Nothing was returned by the pop; go back and re-check the running flag
                continue

            cpu_start = time.process_time()
            wall_start = time.time()

            # Open the APM transaction covering this completion message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            submission = DatabaseSubmission(message)
            self.completed(submission)

            # Close the APM transaction, labelled with the submission id
            if self.apm_client:
                elasticapm.label(sid=submission.sid)
                self.apm_client.end_transaction('ingest_complete', 'success')

            # Account for the time spent handling this message
            cpu_spent = time.process_time() - cpu_start
            self.counter.increment_execution_time('cpu_seconds', cpu_spent)
            wall_spent = time.time() - wall_start
            self.counter.increment_execution_time('busy_seconds', wall_spent)

    def handle_retries(self):
        """Move retry tasks that have become due back onto the ingest queue.

        The loop runs for as long as ``self.sleep`` returns a truthy value.
        Each pass dequeues entries from ``retry_queue`` with an upper limit of
        the current time (``num=100``) and pushes each one onto
        ``ingest_queue``. Time spent per pass is added to the execution-time
        counters and the pass is recorded as an APM transaction when an APM
        client is configured.
        """
        due = []
        # Pause for 3 between passes unless the previous pass found work
        while self.sleep(3 if not due else 0):
            cpu_start = time.process_time()
            wall_start = time.time()

            # Open the APM transaction for this pass
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_retries')

            due = self.retry_queue.dequeue_range(upper_limit=isotime.now(),
                                                 num=100)

            for entry in due:
                self.ingest_queue.push(entry)

            # Close the APM transaction, labelled with the number of requeued tasks
            if self.apm_client:
                elasticapm.label(retries=len(due))
                self.apm_client.end_transaction('ingest_retries', 'success')

            cpu_spent = time.process_time() - cpu_start
            self.counter.increment_execution_time('cpu_seconds', cpu_spent)
            wall_spent = time.time() - wall_start
            self.counter.increment_execution_time('busy_seconds', wall_spent)

    def handle_timeouts(self):
        """Clean up submissions whose timeout entry has come due.

        The loop runs for as long as ``self.sleep`` returns a truthy value.
        Each pass dequeues scan keys from ``timeout_queue`` with an upper limit
        of the current time (``num=100``). For every key the in-progress entry
        is removed from ``scanning`` and the matching duplicate queue is
        drained; each removed item is logged as timed out, and ``timed_out`` is
        incremented once per scan key that had anything to remove.
        """
        timeouts = []
        # Pause for 3 between passes unless the previous pass found work
        while self.sleep(3 if not timeouts else 0):
            cpu_start = time.process_time()
            wall_start = time.time()

            # Open the APM transaction for this pass
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_timeouts')

            timeouts = self.timeout_queue.dequeue_range(
                upper_limit=isotime.now(), num=100)

            for scan_key in timeouts:
                # A failure on one key must not prevent the others from being handled
                # noinspection PyBroadException
                try:
                    # Remove the entry from the hash of submissions in progress.
                    in_progress = self.scanning.pop(scan_key)
                    if in_progress:
                        self.log.error("Submission timed out for %s: %s",
                                       scan_key, str(in_progress))

                    dup_queue_name = _dup_prefix + scan_key
                    duplicate = self.duplicate_queue.pop(dup_queue_name,
                                                         blocking=False)
                    # Either kind of leftover means this key really did time out
                    found_something = bool(in_progress or duplicate)

                    # Drain every duplicate that was waiting on this scan key
                    while duplicate:
                        self.log.error("Submission timed out for %s: %s",
                                       scan_key, str(duplicate))
                        duplicate = self.duplicate_queue.pop(dup_queue_name,
                                                             blocking=False)

                    if found_something:
                        self.counter.increment('timed_out')
                except Exception:
                    self.log.exception("Problem timing out %s:", scan_key)

            # Close the APM transaction, labelled with the number of keys handled
            if self.apm_client:
                elasticapm.label(timeouts=len(timeouts))
                self.apm_client.end_transaction('ingest_timeouts', 'success')

            cpu_spent = time.process_time() - cpu_start
            self.counter.increment_execution_time('cpu_seconds', cpu_spent)
            wall_spent = time.time() - wall_start
            self.counter.increment_execution_time('busy_seconds', wall_spent)

    def get_groups_from_user(self, username: str) -> List[str]:
        """Return the groups of ``username`` through an hourly-reset cache.

        Args:
            username: Name of the user to look up in the datastore.

        Returns:
            The ``groups`` of the user record, or an empty list when the
            datastore has no record for that user. The result (including the
            empty list) is cached until the next reset.
        """
        # Throw the whole cache away once the clock has moved into a new hour
        if self._user_groups_reset < time.time() // HOUR_IN_SECONDS:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        if username in self._user_groups:
            return self._user_groups[username]

        # Unknown user so far: ask the datastore and remember the answer
        record = self.datastore.user.get(username)
        self._user_groups[username] = record.groups if record else []
        return self._user_groups[username]

    def ingest(self, task: IngestTask):
        """Validate a newly received task and queue it for submission.

        The task goes through these steps, any of which may end processing:

        1. Reject it (with an "Invalid sha256" notification) when any file
           sha256 is not 64 characters long.
        2. Remove metadata entries whose value is longer than
           ``submission.max_metadata_length``.
        3. Skip it when it is larger than ``submission.max_file_size``, unless
           ``ignore_size`` or ``never_drop`` is set.
        4. Fill in the groups from the submitter when none are set.
        5. Stamp the scan key and, unless ``ignore_cache`` is set, look for a
           previous result.
        6. Assign a priority when the requested one is negative, and lower it
           for old submissions.
        7. Finalize immediately on a cache hit, otherwise apply the drop and
           whitelist checks and push the task onto ``unique_queue``.

        Args:
            task: The ingest task to process; it is modified in place
                (metadata, groups, scan key, priority, failure).
        """
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Task received for processing"
        )
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        # These counters are incremented for every task received, even ones rejected below
        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        # Every file hash must be exactly 64 characters long
        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task,
                                   failure="Invalid sha256",
                                   logfunc=self.log.warning)
            return

        # Clean up metadata strings, since we may delete some, iterate on a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] '
                    f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        # Oversized files are skipped unless the submitter asked to ignore size or never drop
        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(
                task.params.submitter)

        # Stamp the scan key, then look for a previous result unless the cache is ignored.
        # Misses are not counted here (count_miss=False).
        self.stamp_filescore_key(task)
        pprevious, previous, score = None, None, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task, count_miss=False)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            # A negative priority means "not set": start from medium
            priority = self.priority_value['medium']

            if score is not None:
                # A known score picks the first level whose threshold it reaches, else low
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        current_time = now()
        if priority and self.expired(
                current_time - task.submission.time.timestamp(), 0):
            # "or 1" keeps the reduced priority from becoming falsy
            priority = (priority / 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)

            # On cache hits of any kind we want to send out a completed message
            self.traffic_queue.publish(
                SubmissionMessage({
                    'msg': task.submission,
                    'msg_type': 'SubmissionCompleted',
                    'sender': 'ingester',
                }).as_primitives())
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        # The task survived every check: queue it for submission at its priority
        self.unique_queue.push(priority, task.as_primitives())

    def check(
        self,
        task: IngestTask,
        count_miss=True
    ) -> Tuple[Optional[str], Optional[str], Optional[float], str]:
        """Look for a usable previous result for this task's scan key.

        The in-memory cache is consulted first and the ``filescore`` datastore
        collection second; a datastore hit is copied into the in-memory cache.
        An expired entry is removed from both places, a stale entry is left
        where it is.

        Args:
            task: The ingest task to look up. Its scan key is stamped onto the
                submission if it was not set yet.
            count_miss: When true, a lookup that finds nothing increments the
                ``cache_miss`` counter.

        Returns:
            A ``(psid, sid, score, key)`` tuple. ``psid`` and ``sid`` are None
            when there is no entry or the entry is expired or stale; ``score``
            is None for a miss or an expired entry and is kept for a stale
            one. ``key`` is always the scan key.
        """
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get_if_exists(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                if count_miss:
                    self.counter.increment('cache_miss')
                return None, None, None, key

            # Remember the remote hit locally for the next lookup
            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired"
            )
            self.counter.increment('cache_expired')
            # Every other access to the local cache holds cache_lock; the
            # removal must do the same to stay consistent with them.
            with self.cache_lock:
                self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, None, None, key
        elif self.stale(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale"
            )
            self.counter.increment('cache_stale')
            # The score of a stale entry is still reported to the caller
            return None, None, result.score, key

        return result.psid, result.sid, result.score, key

    def stop(self):
        """Stop the service, the APM instrumentation and the submit client."""
        super().stop()
        # Instrumentation is only undone when an APM client was configured
        if self.apm_client:
            elasticapm.uninstrument()
        self.submit_client.stop()

    def stale(self, delta: float, errors: int):
        """Tell whether a cached result of the given age counts as stale.

        Args:
            delta: Age of the cached result, compared against a threshold
                expressed in seconds.
            errors: Error count of the cached result; any truthy value selects
                the threshold for incomplete results.

        Returns:
            True when ``delta`` has reached the applicable threshold.
        """
        ingester_config = self.config.core.ingester
        threshold = (ingester_config.incomplete_stale_after_seconds
                     if errors else ingester_config.stale_after_seconds)
        return delta >= threshold

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256: str = None) -> str:
        """Return the scan key of a task, creating and storing it if missing.

        Args:
            task: The task whose submission carries (or will carry) the key.
            sha256: Hash to build the key from; defaults to the hash of the
                first file of the submission.

        Returns:
            The existing ``task.submission.scan_key`` when it is set, otherwise
            the newly created key, which is also written back to the
            submission.
        """
        sha256 = sha256 or task.submission.files[0].sha256

        scan_key = task.submission.scan_key
        if scan_key:
            return scan_key

        # No key yet: derive it from the task parameters and remember it
        scan_key = task.params.create_filescore_key(sha256)
        task.submission.scan_key = scan_key
        return scan_key

    def completed(self, sub: DatabaseSubmission):
        """Handle the notification that a submission has finished.

        The in-progress entry for the submission is removed from ``scanning``,
        a ``FileScore`` record is stored in both the local cache and the
        datastore, and the original task plus every duplicate that was waiting
        on the same scan key are finalized.

        Args:
            sub: The completed submission.

        Returns:
            The scan key of the submission.
        """
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.scan_key
        if not scan_key:
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission missing scan key")
            # Rebuild the key from the submission parameters
            scan_key = sub.params.create_filescore_key(sha256)
        in_progress = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not in_progress:
            # Some other worker has already popped the scanning queue?
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission completed twice")
            return scan_key

        task = IngestTask(in_progress)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        # Record the outcome locally (under the cache lock) and in the datastore
        with self.cache_lock:
            file_score = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
            self.cache[scan_key] = file_score
        self.datastore.filescore.save(scan_key, file_score)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            # Yield every queued duplicate for this scan key, stamped with the sid
            while True:
                queued = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                                  blocking=False)
                if queued is None:
                    return
                duplicate = IngestTask(queued)
                duplicate.submission.sid = sid
                yield duplicate

        # The duplicates must be collected into a list BEFORE any of them is
        # finalized. finalize may push onto the duplicate queue that exhaust
        # is draining, so iterating the generator directly in the loop below
        # could turn into an infinite loop.
        dups = list(exhaust())
        for duplicate_task in dups:
            self.finalize(psid, sid, score, duplicate_task)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        """Log a task failure (if any) and push the task to its notification queue.

        Args:
            task: The task to notify about.
            failure: Optional failure text; when truthy it replaces
                ``task.failure``.
            logfunc: Callable used to log the failure; defaults to
                ``self.log.info``.

        Nothing is pushed when the task names no notification queue, or when
        both a threshold and a score are set and the score is below the
        threshold.
        """
        emit = self.log.info if logfunc is None else logfunc

        if failure:
            task.failure = failure

        reason = task.failure
        if reason:
            emit("%s: %s", reason, str(task.json()))

        notification = task.submission.notification
        if not notification.queue:
            return

        note_queue = _notification_queue_prefix + notification.queue
        threshold = notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        # Queue objects are created on first use and kept for later notifications
        queue = self.notification_queues.get(note_queue)
        if not queue:
            queue = NamedQueue(note_queue, self.redis_persist)
            self.notification_queues[note_queue] = queue
        queue.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        """Tell whether something of the given age counts as expired.

        Args:
            delta: Age to compare against the configured threshold.
            errors: Any truthy value selects the threshold for incomplete
                results (``incomplete_expire_after_seconds``); otherwise
                ``expire_after`` applies.

        Returns:
            True when ``delta`` has reached the applicable threshold.
        """
        ingester_config = self.config.core.ingester
        threshold = (ingester_config.incomplete_expire_after_seconds
                     if errors else ingester_config.expire_after)
        return delta >= threshold

    def drop(self, task: IngestTask) -> bool:
        """Decide whether a task should be dropped instead of being queued.

        A task is a drop candidate when its priority is at or below
        ``_min_priority``, when ``must_drop`` says so for the first matching
        priority range that has a sampling threshold, or when its file size is
        zero or above ``submission.max_file_size``. A candidate with
        ``never_drop`` set is kept anyway.

        Returns:
            True when the task was dropped (failure set to 'Skipped', drop
            notification sent, ``skipped`` counter incremented), else False.
        """
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        if priority <= _min_priority:
            dropped = True
        else:
            dropped = False
            # Only the first range that contains the priority and has a threshold is consulted
            for level, bounds in self.priority_range.items():
                if bounds[0] <= priority <= bounds[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*bounds),
                                        sample_threshold[level])
                    break

            if not dropped:
                size = task.file_size
                dropped = size > self.config.submission.max_file_size or size == 0

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        """Send the notification for a dropped task and save or freshen its file record.

        The first file's sha256 is passed to ``save_or_freshen_file`` together
        with the task classification and ``now_as_iso(86400)`` as expiry
        (86400 is the number of seconds in a day; presumably an offset from
        now - confirm against ``now_as_iso``).
        """
        self.send_notification(task)

        sha256 = task.submission.files[0].sha256
        self.datastore.save_or_freshen_file(
            sha256,
            {'sha256': sha256},
            now_as_iso(86400),
            task.params.classification,
            redis=self.redis,
        )

    def is_whitelisted(self, task: IngestTask):
        """Check a task against the whitelist and drop it on a match.

        The whitelist verdict is computed first; when it gives no reason the
        per-hash ``whitelisted`` cache is consulted. On a match the task's
        failure text is set, a drop notification is sent and the
        ``whitelisted`` counter is incremented.

        Returns:
            The whitelisting reason, which is falsy when the task is not
            whitelisted.
        """
        reason, raw_hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {field: dotdump(safe_str(value)) for field, value in raw_hit.items()}
        sha256 = task.submission.files[0].sha256

        from_cache = False
        if not reason:
            # No fresh verdict: fall back on a reason remembered for this hash
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
            if reason:
                hit = 'cached'
                from_cache = True

        if not reason:
            return reason

        if not from_cache:
            # Remember fresh verdicts so later tasks with this hash hit the cache
            with self.whitelisted_lock:
                self.whitelisted[sha256] = reason

        task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(
            safe_str(reason)), hit)
        self._notify_drop(task)

        self.counter.increment('whitelisted')
        return reason

    def submit(self, task: IngestTask):
        """Send the task's submission off and register its timeout.

        The submission goes to ``submit_client`` with ``COMPLETE_QUEUE_NAME``
        as its completed queue, after which the scan key is pushed onto
        ``timeout_queue`` with ``int(now(_max_time))`` as its value.
        """
        self.submit_client.submit(submission_obj=task.submission,
                                  completed_queue=COMPLETE_QUEUE_NAME)

        timeout_at = int(now(_max_time))
        self.timeout_queue.push(timeout_at, task.submission.scan_key)
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis"
        )

    def retry(self, task: IngestTask, scan_key: str, ex):
        """Schedule another attempt for a failed task, or give up on it.

        Args:
            task: The task whose submission failed.
            scan_key: Scan key of the task, used to locate its duplicate queue.
            ex: The exception that caused the failure, if any.

        The task is abandoned (and its duplicate queue deleted) when the retry
        count would exceed ``_max_retries`` or when the task has expired since
        its ingest time. Otherwise the retry count is stored on the task and
        the task is pushed onto ``retry_queue`` with ``int(now(_retry_delay))``
        as its value.
        """
        current_time = now()
        attempt = task.retries + 1

        if attempt > _max_retries:
            trace = ': ' + get_stacktrace_info(ex) if ex else ''
            self.log.error(
                f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
            return

        if self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
            return

        self.log.info(
            f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})'
        )
        task.retries = attempt
        self.retry_queue.push(int(now(_retry_delay)), task.as_primitives())

    def finalize(self, psid: str, sid: str, score: float, task: IngestTask):
        """Wrap up a task: alert, notify and, if warranted, resubmit it.

        Args:
            psid: Parent submission id; only written to the task when truthy.
            sid: Submission id to record on the task.
            score: Score to record on the task.
            task: The task being finalized; it is modified in place.

        The statement order matters: when the task is going to be resubmitted
        its psid is cleared before the alert and notification are sent, and
        only afterwards set to ``sid`` for the resubmission.
        """
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        # A resubmission needs both a service selection to resubmit with and a qualifying score
        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            # With the psid cleared, the alert message below reads 'create' rather than 'update'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                f"to {'update' if task.params.psid else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis"
            )
            # Turn the task into a new submission whose parent is the one just completed
            task.params.psid = sid
            task.submission.sid = None
            # Clearing the scan key means a new one is created the next time the task is stamped
            task.submission.scan_key = None
            # Emptying the resubmit list keeps the resubmission from triggering another one
            task.params.services.resubmit = []
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score: float) -> bool:
        """Tell whether a finished task should produce an alert.

        Args:
            task: The task; alerts are only considered when its
                ``generate_alert`` parameter is truthy.
            score: Score compared against ``core.alerter.threshold``.

        Returns:
            True when alert generation is requested and the score is not
            below the alerter threshold.
        """
        if task.params.generate_alert:
            return not score < self.config.core.alerter.threshold
        return False
# Example no. 25
# 0
def test_status_namspace(datastore, sio):
    """Check that heartbeats published on the status queue reach '/status' listeners.

    One random heartbeat message of each kind (alerter, dispatcher, expiry,
    ingest, service) is published on the 'status' queue after the client has
    emitted 'monitor'. The test expects six events in total - the 'monitoring'
    confirmation plus the five heartbeats - each carrying the expected data.
    """
    status_queue = CommsQueue('status', private=True)
    monitoring = get_random_id()

    alerter_hb_msg = random_model_obj(AlerterMessage).as_primitives()
    dispatcher_hb_msg = random_model_obj(DispatcherMessage).as_primitives()
    expiry_hb_msg = random_model_obj(ExpiryMessage).as_primitives()
    ingest_hb_msg = random_model_obj(IngestMessage).as_primitives()
    service_hb_msg = random_model_obj(ServiceMessage).as_primitives()

    # (handler name, data matched) for every event received
    received = []
    expected_events = 6

    def record(handler_name, matched):
        received.append((handler_name, matched))

    @sio.on('monitoring', namespace='/status')
    def on_monitoring(data):
        # Confirmation that we are waiting for status messages
        record('on_monitoring', data == monitoring)

    @sio.on('AlerterHeartbeat', namespace='/status')
    def on_alerter_heartbeat(data):
        record('on_alerter_heartbeat', data == alerter_hb_msg['msg'])

    @sio.on('DispatcherHeartbeat', namespace='/status')
    def on_dispatcher_heartbeat(data):
        record('on_dispatcher_heartbeat', data == dispatcher_hb_msg['msg'])

    @sio.on('ExpiryHeartbeat', namespace='/status')
    def on_expiry_heartbeat(data):
        record('on_expiry_heartbeat', data == expiry_hb_msg['msg'])

    @sio.on('IngestHeartbeat', namespace='/status')
    def on_ingest_heartbeat(data):
        record('on_ingest_heartbeat', data == ingest_hb_msg['msg'])

    @sio.on('ServiceHeartbeat', namespace='/status')
    def on_service_heartbeat(data):
        record('on_service_heartbeat', data == service_hb_msg['msg'])

    try:
        sio.emit('monitor', monitoring, namespace='/status')
        sio.sleep(1)

        for heartbeat in (alerter_hb_msg, dispatcher_hb_msg, expiry_hb_msg,
                          ingest_hb_msg, service_hb_msg):
            status_queue.publish(heartbeat)

        # Wait for all events to arrive, for at most 5 seconds
        deadline = time.time() + 5
        while len(received) < expected_events and time.time() < deadline:
            sio.sleep(0.1)

        assert len(received) == expected_events

        for handler_name, matched in received:
            if not matched:
                pytest.fail(f"{handler_name} failed.")
    finally:
        sio.disconnect()