def try_run(self):
        counter = self.counter
        apm_client = self.apm_client

        while self.running:
            self.heartbeat()

            # Download all messages from the queue that have expired
            seconds, _ = retry_call(self.redis.time)
            messages = self.queue.dequeue_range(0, seconds)

            cpu_mark = time.process_time()
            time_mark = time.time()

            # Try to pass on all the messages to their intended recipients; don't let
            # the failure of one message prevent the others from going through
            for key in messages:
                # Start of transaction
                if apm_client:
                    apm_client.begin_transaction('process_messages')

                message = self.hash.pop(key)
                if message:
                    try:
                        if message['action'] == WatcherAction.TimeoutTask:
                            self.cancel_service_task(message['task_key'],
                                                     message['worker'])
                        else:
                            queue = NamedQueue(message['queue'], self.redis)
                            queue.push(message['message'])

                        self.counter.increment('expired')
                        # End of transaction (success)
                        if apm_client:
                            apm_client.end_transaction('watch_message',
                                                       'success')
                    except Exception as error:
                        # End of transaction (exception)
                        if apm_client:
                            apm_client.end_transaction('watch_message',
                                                       'error')

                        self.log.exception(error)
                else:
                    # End of transaction (duplicate)
                    if apm_client:
                        apm_client.end_transaction('watch_message',
                                                   'duplicate')

                    self.log.warning(
                        f'Handled watch twice: {key} {len(key)} {type(key)}')

            counter.increment_execution_time('cpu_seconds',
                                             time.process_time() - cpu_mark)
            counter.increment_execution_time('busy_seconds',
                                             time.time() - time_mark)

            if not messages:
                time.sleep(0.1)
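# A minimal sketch of the producer side of the loop above, assuming the watcher stores
# payloads in a Hash keyed by message id and schedules the ids in a PriorityQueue by
# expiry time. The queue/hash names and the helper itself are illustrative assumptions,
# not the project's real WatcherClient API.
def schedule_timeout_message(redis, task_key, worker, timeout_at):
    hash_table = Hash('watcher-messages', redis)         # popped above via self.hash.pop(key)
    schedule = PriorityQueue('watcher-schedule', redis)  # read above via dequeue_range(0, now)

    key = f'watch-{task_key}'
    hash_table.add(key, {
        'action': WatcherAction.TimeoutTask,
        'task_key': task_key,
        'worker': worker,
    })
    # Once timeout_at (a unix timestamp) has passed, dequeue_range(0, now) will return the key
    schedule.push(timeout_at, key)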
def default_authenticator(auth, req, ses, storage):
    # This is the Assemblyline authentication procedure.
    # It will try to authenticate the user with the following methods, in order, until one succeeds:
    #    apikey
    #    username/password
    #    PKI DN
    #
    # During the authentication procedure the user/pass and DN methods are subject to an OTP challenge
    # if OTP is allowed on the server and has been turned on by the user.
    #
    # Apikey authentication is not subject to the OTP challenge but has limited functionality.

    apikey = auth.get('apikey', None)
    otp = auth.get('otp', 0)
    webauthn_auth_resp = auth.get('webauthn_auth_resp', None)
    state = ses.pop('state', None)
    password = auth.get('password', None)
    uname = auth.get('username', None)
    oauth_token = auth.get('oauth_token', None)

    if not uname:
        raise AuthenticationException('No user specified for authentication')

    # Bruteforce protection
    auth_fail_queue = NamedQueue("ui-failed-%s" % uname, **nonpersistent_config)
    if auth_fail_queue.length() >= config.auth.internal.max_failures:
        # Failed 'max_failures' times, stop trying... This will time out in 'failure_ttl' seconds
        raise AuthenticationException("Maximum password retry of {retry} was reached. "
                                      "This account is locked for the next {ttl} "
                                      "seconds...".format(retry=config.auth.internal.max_failures,
                                                          ttl=config.auth.internal.failure_ttl))

    try:
        validated_user, priv = validate_apikey(uname, apikey, storage)
        if validated_user:
            return validated_user, priv

        validated_user, priv = validate_oauth(uname, oauth_token)
        if not validated_user:
            validated_user, priv = validate_ldapuser(uname, password, storage)
        if not validated_user:
            validated_user, priv = validate_userpass(uname, password, storage)

        if validated_user:
            validate_2fa(validated_user, otp, state, webauthn_auth_resp, storage)
            return validated_user, priv

    except AuthenticationException:
        # Authentication failed: record the failure parameters for brute-force protection
        auth_fail_queue.push({
            'remote_addr': req.remote_addr,
            'host': req.host,
            'full_path': req.full_path
        })

        raise

    raise AuthenticationException("None of the authentication methods succeeded")
def test_get_message(datastore, client):
    notification_queue = get_random_id()
    queue = NamedQueue("nq-%s" % notification_queue,
                       host=config.core.redis.persistent.host,
                       port=config.core.redis.persistent.port)
    queue.delete()
    msg = random_model_obj(Submission).as_primitives()
    queue.push(msg)

    res = client.ingest.get_message(notification_queue)
    assert isinstance(res, dict)
    assert 'sid' in res
    assert 'results' in res
    assert res == msg
def test_get_message_list(datastore, client):
    notification_queue = get_random_id()
    queue = NamedQueue("nq-%s" % notification_queue,
                       host=config.core.redis.persistent.host,
                       port=config.core.redis.persistent.port)
    queue.delete()
    msg_0 = random_model_obj(Submission).as_primitives()
    queue.push(msg_0)
    msg_1 = random_model_obj(Submission).as_primitives()
    queue.push(msg_1)

    res = client.ingest.get_message_list(notification_queue)
    assert len(res) == 2
    assert res[0] == msg_0
    assert res[1] == msg_1
def backup_worker(worker_id, instance_id, working_dir):
    datastore = forge.get_datastore(archive_access=True)
    worker_queue = NamedQueue(f"r-worker-{instance_id}", ttl=1800)
    done_queue = NamedQueue(f"r-done-{instance_id}", ttl=1800)
    hash_queue = Hash(f"r-hash-{instance_id}")
    stopping = False
    with open(os.path.join(working_dir, "backup.part%s" % worker_id),
              "w+") as backup_file:
        while True:
            data = worker_queue.pop(timeout=1)
            if data is None:
                if stopping:
                    break
                continue

            if data.get('stop', False):
                if not stopping:
                    stopping = True
                else:
                    time.sleep(round(random.uniform(0.050, 0.250), 3))
                    worker_queue.push(data)
                continue

            missing = False
            success = True
            try:
                to_write = datastore.get_collection(data['bucket_name']).get(
                    data['key'], as_obj=False)
                if to_write:
                    if data.get('follow_keys', False):
                        for bucket, bucket_key, getter in FOLLOW_KEYS.get(
                                data['bucket_name'], []):
                            for key in getter(to_write.get(bucket_key, None)):
                                hash_key = "%s_%s" % (bucket, key)
                                if not hash_queue.exists(hash_key):
                                    hash_queue.add(hash_key, "True")
                                    worker_queue.push({
                                        "bucket_name": bucket,
                                        "key": key,
                                        "follow_keys": True
                                    })

                    backup_file.write(
                        json.dumps((data['bucket_name'], data['key'],
                                    to_write)) + "\n")
                else:
                    missing = True
            except Exception:
                success = False

            done_queue.push({
                "success": success,
                "missing": missing,
                "bucket_name": data['bucket_name'],
                "key": data['key']
            })

    done_queue.push({"stopped": True})
    def setup_watch_queue(self, sid):
        """
        This function takes a submission ID as a parameter and creates a unique queue to which all service
        result keys for the given submission will be pushed as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        # Create a unique queue
        queue_name = reply_queue_name(prefix="D", suffix="WQ")
        watch_queue = NamedQueue(queue_name, ttl=30)
        watch_queue.push(
            WatchQueueMessage({
                'status': 'START'
            }).as_primitives())

        # Add the newly created queue to the list of queues for the given submission
        self._get_watcher_list(sid).add(queue_name)

        # Push all current keys to the newly created queue (Queue should have a TTL of about 30 sec to 1 minute)
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        if dispatch_hash.dispatch_count() == 0 and dispatch_hash.finished_count() == 0:
            # This table is empty? do we have this submission at all?
            submission = self.ds.submission.get(sid)
            if not submission or submission.state == 'completed':
                watch_queue.push(
                    WatchQueueMessage({
                        "status": "STOP"
                    }).as_primitives())
            else:
                # We do have a submission, remind the dispatcher to work on it
                self.submission_queue.push({'sid': sid})

        else:
            all_service_status = dispatch_hash.all_results()
            for status_values in all_service_status.values():
                for status in status_values.values():
                    if status.is_error:
                        watch_queue.push(
                            WatchQueueMessage({
                                "status": "FAIL",
                                "cache_key": status.key
                            }).as_primitives())
                    else:
                        watch_queue.push(
                            WatchQueueMessage({
                                "status": "OK",
                                "cache_key": status.key
                            }).as_primitives())

        return queue_name
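# Illustrative consumer of a watch queue created by setup_watch_queue above. It assumes
# access to the same redis host and follows the WatchQueueMessage statuses used here
# (START, OK, FAIL, STOP); the helper name is not part of the real API.
def follow_watch_queue(redis, queue_name):
    watch_queue = NamedQueue(queue_name, host=redis)
    while True:
        msg = watch_queue.pop(timeout=30)
        if msg is None or msg['status'] == 'STOP':
            break
        if msg['status'] in ('OK', 'FAIL'):
            print(msg['status'], msg.get('cache_key'))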
def test_live_namespace(datastore, sio):
    wq_data = {'wq_id': get_random_id()}
    wq = NamedQueue(wq_data['wq_id'], private=True)

    start_msg = {'status_code': 200, 'msg': "Start listening..."}
    stop_msg = {
        'status_code': 200,
        'msg': "All messages received, closing queue..."
    }
    cachekey_msg = {'status_code': 200, 'msg': get_random_id()}
    cachekeyerr_msg = {'status_code': 200, 'msg': get_random_id()}

    test_res_array = []

    @sio.on('start', namespace='/live_submission')
    def on_start(data):
        test_res_array.append(('on_start', data == start_msg))

    @sio.on('stop', namespace='/live_submission')
    def on_stop(data):
        test_res_array.append(('on_stop', data == stop_msg))

    @sio.on('cachekey', namespace='/live_submission')
    def on_cachekey(data):
        test_res_array.append(('on_cachekey', data == cachekey_msg))

    @sio.on('cachekeyerr', namespace='/live_submission')
    def on_cachekeyerr(data):
        test_res_array.append(('on_cachekeyerr', data == cachekeyerr_msg))

    try:
        sio.emit('listen', wq_data, namespace='/live_submission')
        sio.sleep(1)

        wq.push({"status": "START"})
        wq.push({"status": "OK", "cache_key": cachekey_msg['msg']})
        wq.push({"status": "FAIL", "cache_key": cachekeyerr_msg['msg']})
        wq.push({"status": "STOP"})

        start_time = time.time()

        while len(test_res_array) < 4 and time.time() - start_time < 5:
            sio.sleep(0.1)

        assert len(test_res_array) == 4

        for test, result in test_res_array:
            if not result:
                pytest.fail(f"{test} failed.")

    finally:
        sio.disconnect()
    def setup_watch_queue(self, sid: str) -> Optional[str]:
        """
        This function takes a submission ID as a parameter and creates a unique queue to which all service
        result keys for the given submission will be pushed as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="WQ")
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': CREATE_WATCH,
                'payload_data': CreateWatch({
                    'queue_name': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue_name
    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of files each
        of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and the number of files remaining
                 per service, e.g. {"SERVICE_NAME": 1, ...}
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
            queue = NamedQueue(queue_name, host=self.redis, ttl=30)
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': LIST_OUTSTANDING,
                'payload_data': ListOutstanding({
                    'response_queue': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue.pop(timeout=30)
        return {}
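# Hypothetical polling loop built on outstanding_services; `client` is any object exposing
# the method above, and the printing/sleep are only for illustration.
def wait_for_submission(client, sid, poll_interval=5):
    while True:
        outstanding = client.outstanding_services(sid)
        if not outstanding:
            return
        for service, file_count in outstanding.items():
            print(f"{service}: {file_count} file(s) remaining")
        time.sleep(poll_interval)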
def restore_worker(worker_id, instance_id, working_dir):
    datastore = forge.get_datastore(archive_access=True)
    done_queue = NamedQueue(f"r-done-{instance_id}", ttl=1800)

    with open(os.path.join(working_dir, "backup.part%s" % worker_id),
              "rb") as input_file:
        for line in input_file:
            bucket_name, key, data = json.loads(line)

            success = True
            try:
                collection = datastore.get_collection(bucket_name)
                collection.save(key, data)
            except Exception:
                success = False

            done_queue.push({
                "success": success,
                "missing": False,
                "bucket_name": bucket_name,
                "key": key
            })

    done_queue.push({"stopped": True})
class DispatchClient:
    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        self.redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger("assemblyline.dispatching.client")
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.submission_assignments = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = Hash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))
        self.dispatcher_data = []
        self.dispatcher_data_age = 0.0
        self.dead_dispatchers = []

    @weak_lru(maxsize=128)
    def _get_queue_from_cache(self, name):
        return NamedQueue(name, host=self.redis, ttl=QUEUE_EXPIRY)

    def _get_services(self):
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def is_dispatcher(self, dispatcher_id) -> bool:
        if dispatcher_id in self.dead_dispatchers:
            return False
        if time.time() - self.dispatcher_data_age > 120 or dispatcher_id not in self.dispatcher_data:
            self.dispatcher_data = Dispatcher.all_instances(self.redis_persist)
            self.dispatcher_data_age = time.time()
        if dispatcher_id in self.dispatcher_data:
            return True
        else:
            self.dead_dispatchers.append(dispatcher_id)
            return False

    def dispatch_bundle(self, submission: Submission, results: Dict[str, Result],
                        file_infos: Dict[str, File], file_tree, errors: Dict[str, Error], completed_queue: str = None):
        """Insert a bundle into the dispatching system and continue scanning of its files

        Prerequisites:
            - Submission, results, file_infos and errors should already be saved in the datastore
            - Files should already be in the filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            results=results,
            file_infos=file_infos,
            file_tree=file_tree,
            errors=errors,
            completed_queue=completed_queue,
        ))

    def dispatch_submission(self, submission: Submission, completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            completed_queue=completed_queue,
        ))

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of files each
        of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and the number of files remaining
                 per service, e.g. {"SERVICE_NAME": 1, ...}
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
            queue = NamedQueue(queue_name, host=self.redis, ttl=30)
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': LIST_OUTSTANDING,
                'payload_data': ListOutstanding({
                    'response_queue': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue.pop(timeout=30)
        return {}

    @elasticapm.capture_span(span_type='dispatch_client')
    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True, low_priority=False) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param worker_id: Unique identifier of the worker requesting work.
        :param service_name: Which service needs work.
        :param service_version: Version of the service requesting work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately.
        :return: The task found, or None if no work was available before the timeout.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            work = self._request_work(worker_id, service_name, service_version,
                                      blocking=blocking, timeout=remaining, low_priority=low_priority)
            if work or not blocking:
                return work
            remaining = timeout - (time.time() - start)
        return None
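
    # Illustrative service-worker loop built on request_work / service_finished. It is not
    # part of the real client; process_task is a placeholder callable that runs the service
    # and returns a (result_key, Result) pair.
    def _example_worker_loop(self, worker_id, service_name, service_version, process_task):
        while True:
            task = self.request_work(worker_id, service_name, service_version, timeout=30)
            if task is None:
                continue  # Timed out waiting for work; poll again
            result_key, result = process_task(task)
            self.service_finished(task.sid, result_key, result)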

    def _request_work(self, worker_id, service_name, service_version,
                      timeout, blocking, low_priority=False) -> Optional[ServiceTask]:
        # For when we recursively retry on bad task dequeueing
        if int(timeout) <= 0:
            self.log.info(f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout), low_priority=low_priority)
        else:
            if low_priority:
                result = work_queue.unpush(1)
            else:
                result = work_queue.pop(1)
            if result:
                result = result[0]

        if not result:
            self.log.info(f"{service_name}:{worker_id} no task returned: [empty message]")
            return None
        task = ServiceTask(result)
        task.metadata['worker__'] = worker_id
        dispatcher = task.metadata['dispatcher__']

        if not self.is_dispatcher(dispatcher):
            self.log.info(f"{service_name}:{worker_id} no task returned: [task from dead dispatcher]")
            return None

        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found")
            start_queue = self._get_queue_from_cache(DISPATCH_START_EVENTS + dispatcher)
            start_queue.push((task.sid, task.fileinfo.sha256, service_name, worker_id))
            return task
        return None

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(sid=sid, service_name=result.response.service_name, sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing successful results.")
            return
        task = ServiceTask(task)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            while True:
                old, version = self.ds.result.get_if_exists(
                    result_key, archive_access=self.config.datastore.ilm.update_archive, version=True)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        result.expiry_ts = None
                try:
                    self.ds.result.save(result_key, result, version=version)
                    break
                except VersionConflictException as vce:
                    self.log.info(f"Retrying to save results due to version conflict: {str(vce)}")

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w, host=self.redis).push(msg)

        # Save the tags
        tags = []
        for section in result.result.sections:
            tags.extend(tag_dict_to_list(flatten(section.tags.as_primitives())))

        # Pull out file names if we have them
        file_names = {}
        for extracted_data in result.response.extracted:
            if extracted_data.name:
                file_names[extracted_data.sha256] = extracted_data.name

        # Forward the result summary to the dispatcher instance responsible for this task
        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        ex_ts = result.expiry_ts.strftime(DATEFORMAT) if result.expiry_ts else result.archive_ts.strftime(DATEFORMAT)
        result_queue.push({
            # 'service_task': task.as_primitives(),
            # 'result': result.as_primitives(),
            'sid': task.sid,
            'sha256': result.sha256,
            'service_name': task.service_name,
            'service_version': result.response.service_version,
            'service_tool_version': result.response.service_tool_version,
            'archive_ts': result.archive_ts.strftime(DATEFORMAT),
            'expiry_ts': ex_ts,
            'result_summary': {
                'key': result_key,
                'drop': result.drop_file,
                'score': result.result.score,
                'children': [r.sha256 for r in result.response.extracted],
            },
            'tags': tags,
            'extracted_names': file_names,
            'temporary_data': temporary_data
        })

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_failed(self, sid: str, error_key: str, error: Error):
        task_key = ServiceTask.make_key(sid=sid, service_name=error.response.service_name, sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error.")
        if error.response.status == "FAIL_NONRECOVERABLE":
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w, host=self.redis).push(msg)

        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        result_queue.push({
            'sid': task.sid,
            'service_task': task.as_primitives(),
            'error': error.as_primitives(),
            'error_key': error_key
        })

    def setup_watch_queue(self, sid: str) -> Optional[str]:
        """
        This function takes a submission ID as a parameter and creates a unique queue to which all service
        result keys for the given submission will be pushed as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="WQ")
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': CREATE_WATCH,
                'payload_data': CreateWatch({
                    'queue_name': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue_name

    def _get_watcher_list(self, sid):
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)
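# Illustrative end-to-end use of DispatchClient: push a submission, then try to attach a
# watch queue. It assumes the submission and its files are already in the datastore and
# filestore (see the docstrings above); setup_watch_queue returns None until a dispatcher
# has picked the submission up, so a real caller would retry.
def dispatch_and_watch(submission):
    client = DispatchClient()
    client.dispatch_submission(submission)
    queue_name = client.setup_watch_queue(submission.sid)
    if queue_name is None:
        return None
    watch_queue = NamedQueue(queue_name, host=client.redis, ttl=30)
    return watch_queue.pop(timeout=30)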
class Ingester(ThreadedCoreBase):
    def __init__(self,
                 datastore=None,
                 logger=None,
                 classification=None,
                 redis=None,
                 persistent_redis=None,
                 metrics_name='ingester',
                 config=None):
        super().__init__('assemblyline.ingester',
                         logger,
                         redis=redis,
                         redis_persist=persistent_redis,
                         datastore=datastore,
                         config=config)

        # Cache the user groups
        self.cache_lock = threading.RLock()
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so this has to be done at init rather than at load time
        constants = forge.get_constants(self.config)
        self.priority_value: dict[str, int] = constants.PRIORITIES
        self.priority_range: dict[str, Tuple[int, int]] = constants.PRIORITY_RANGES
        self.threshold_value: dict[str, int] = constants.PRIORITY_THRESHOLDS

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester',
                                      schema=Metrics,
                                      redis=self.redis,
                                      config=self.config,
                                      name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.redis_persist)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.redis_persist)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.redis_persist)

        # Internal, timeout watch queue
        self.timeout_queue: PriorityQueue[str] = PriorityQueue(
            'm-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the complete
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestions of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.redis_persist)

        # Output. submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore,
                                              redis=self.redis)

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="ingester")
        else:
            self.apm_client = None

    def try_run(self):
        threads_to_maintain = {
            'Retries': self.handle_retries,
            'Timeouts': self.handle_timeouts
        }
        threads_to_maintain.update({
            f'Complete_{n}': self.handle_complete
            for n in range(COMPLETE_THREADS)
        })
        threads_to_maintain.update(
            {f'Ingest_{n}': self.handle_ingest
             for n in range(INGEST_THREADS)})
        threads_to_maintain.update(
            {f'Submit_{n}': self.handle_submit
             for n in range(SUBMIT_THREADS)})
        self.maintain_threads(threads_to_maintain)

    def handle_ingest(self):
        cpu_mark = time.process_time()
        time_mark = time.time()

        # Move from ingest to unique and waiting queues.
        # While there are entries in the ingest queue we consume chunk_size
        # entries at a time and move unique entries to uniqueq / queued and
        # duplicates to their own queues / waiting.
        while self.running:
            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

            message = self.ingest_queue.pop(timeout=1)

            cpu_mark = time.process_time()
            time_mark = time.time()

            if not message:
                continue

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            try:
                if 'submission' in message:
                    # A retried task
                    task = IngestTask(message)
                else:
                    # A new submission
                    sub = MessageSubmission(message)
                    task = IngestTask(dict(
                        submission=sub,
                        ingest_id=sub.sid,
                    ))
                    task.submission.sid = None  # Reset to new random uuid
                    # Write all input to the traffic queue
                    self.traffic_queue.publish(
                        SubmissionMessage({
                            'msg': sub,
                            'msg_type': 'SubmissionIngested',
                            'sender': 'ingester',
                        }).as_primitives())

            except (ValueError, TypeError) as error:
                self.counter.increment('error')
                self.log.exception(
                    f"Dropped ingest submission {message} because {str(error)}"
                )

                # End of ingest message (value_error)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_input',
                                                    'value_error')
                continue

            self.ingest(task)

            # End of ingest message (success)
            if self.apm_client:
                self.apm_client.end_transaction('ingest_input', 'success')

    def handle_submit(self):
        time_mark, cpu_mark = time.time(), time.process_time()

        while self.running:
            # noinspection PyBroadException
            try:
                self.counter.increment_execution_time(
                    'cpu_seconds',
                    time.process_time() - cpu_mark)
                self.counter.increment_execution_time('busy_seconds',
                                                      time.time() - time_mark)

                # Check if there is room for more submissions
                length = self.scanning.length()
                if length >= self.config.core.ingester.max_inflight:
                    self.sleep(0.1)
                    time_mark, cpu_mark = time.time(), time.process_time()
                    continue

                raw = self.unique_queue.blocking_pop(timeout=3)
                time_mark, cpu_mark = time.time(), time.process_time()
                if not raw:
                    continue

                # Start of ingest message
                if self.apm_client:
                    self.apm_client.begin_transaction('ingest_msg')

                task = IngestTask(raw)

                # Check if we need to drop a file for capacity reasons, but only if the
                # number of files in flight is already over 80%
                if length >= self.config.core.ingester.max_inflight * 0.8 and self.drop(
                        task):
                    # End of ingest message (dropped)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'dropped')
                    continue

                if self.is_whitelisted(task):
                    # End of ingest message (whitelisted)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'whitelisted')
                    continue

                # Check if this file has been previously processed.
                pprevious, previous, score, scan_key = None, None, None, None
                if not task.submission.params.ignore_cache:
                    pprevious, previous, score, scan_key = self.check(task)
                else:
                    scan_key = self.stamp_filescore_key(task)

                # If it HAS been previously processed, we are dealing with a resubmission.
                # finalize will decide what to do, and put the task back in the queue,
                # rewritten properly, if we are going to run it again
                if previous:
                    if not task.submission.params.services.resubmit and not pprevious:
                        self.log.warning(
                            f"No psid for what looks like a resubmission of "
                            f"{task.submission.files[0].sha256}: {scan_key}")
                    self.finalize(pprevious, previous, score, task)
                    # End of ingest message (finalized)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'finalized')

                    continue

                # We have decided this file is worth processing

                # Add the task to the scanning table; this is atomic across all submit
                # workers, so if it fails, someone beat us to the punch and we record
                # the file as a duplicate instead.
                if not self.scanning.add(scan_key, task.as_primitives()):
                    self.log.debug('Duplicate %s',
                                   task.submission.files[0].sha256)
                    self.counter.increment('duplicates')
                    self.duplicate_queue.push(_dup_prefix + scan_key,
                                              task.as_primitives())
                    # End of ingest message (duplicate)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'duplicate')

                    continue

                # We have managed to add the task to the scan table, so now we go
                # ahead with the submission process
                try:
                    self.submit(task)
                    # End of ingest message (submitted)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'submitted')

                    continue
                except Exception as _ex:
                    # For some reason (contained in `ex`) we have failed the submission
                    # The rest of this function is error handling/recovery
                    ex = _ex
                    # traceback = _ex.__traceback__

                self.counter.increment('error')

                should_retry = True
                if isinstance(ex, CorruptedFileStoreException):
                    self.log.error(
                        "Submission for file '%s' failed due to corrupted "
                        "filestore: %s" % (task.sha256, str(ex)))
                    should_retry = False
                elif isinstance(ex, DataStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed due to "
                                   "data store error:\n%s" %
                                   (task.sha256, trace))
                elif not isinstance(ex, FileStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed: %s" %
                                   (task.sha256, trace))

                # Pull the original task back out of the scanning table before retrying
                raw_task = self.scanning.pop(scan_key)
                if not raw_task:
                    self.log.error('No scanning entry for %s', task.sha256)
                    # End of ingest message (no_scan_entry)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'no_scan_entry')

                    continue
                task = IngestTask(raw_task)

                if not should_retry:
                    # End of ingest message (cannot_retry)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'cannot_retry')

                    continue

                self.retry(task, scan_key, ex)
                # End of ingest message (retry)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit', 'retried')

            except Exception:
                self.log.exception("Unexpected error")
                # End of ingest message (exception)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit',
                                                    'exception')

    def handle_complete(self):
        while self.running:
            result = self.complete_queue.pop(timeout=3)
            if not result:
                continue

            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            sub = DatabaseSubmission(result)
            self.completed(sub)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(sid=sub.sid)
                self.apm_client.end_transaction('ingest_complete', 'success')

            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

    def handle_retries(self):
        tasks = []
        while self.sleep(0 if tasks else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_retries')

            tasks = self.retry_queue.dequeue_range(upper_limit=isotime.now(),
                                                   num=100)

            for task in tasks:
                self.ingest_queue.push(task)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(retries=len(tasks))
                self.apm_client.end_transaction('ingest_retries', 'success')

            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

    def handle_timeouts(self):
        timeouts = []
        while self.sleep(0 if timeouts else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_timeouts')

            timeouts = self.timeout_queue.dequeue_range(
                upper_limit=isotime.now(), num=100)

            for scan_key in timeouts:
                # noinspection PyBroadException
                try:
                    actual_timeout = False

                    # Remove the entry from the hash of submissions in progress.
                    entry = self.scanning.pop(scan_key)
                    if entry:
                        actual_timeout = True
                        self.log.error("Submission timed out for %s: %s",
                                       scan_key, str(entry))

                    dup = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                                   blocking=False)
                    if dup:
                        actual_timeout = True

                    while dup:
                        self.log.error("Submission timed out for %s: %s",
                                       scan_key, str(dup))
                        dup = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                                       blocking=False)

                    if actual_timeout:
                        self.counter.increment('timed_out')
                except Exception:
                    self.log.exception("Problem timing out %s:", scan_key)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(timeouts=len(timeouts))
                self.apm_client.end_transaction('ingest_timeouts', 'success')

            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

    def get_groups_from_user(self, username: str) -> List[str]:
        # Reset the group cache at the top of each hour
        if time.time() // HOUR_IN_SECONDS > self._user_groups_reset:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        # Get the groups for this user if not known
        if username not in self._user_groups:
            user_data = self.datastore.user.get(username)
            if user_data:
                self._user_groups[username] = user_data.groups
            else:
                self._user_groups[username] = []
        return self._user_groups[username]

    def ingest(self, task: IngestTask):
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Task received for processing"
        )
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task,
                                   failure="Invalid sha256",
                                   logfunc=self.log.warning)
            return

        # Clean up metadata strings; since we may delete some, iterate over a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] '
                    f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(
                task.params.submitter)

        # Check if this file is already being processed
        self.stamp_filescore_key(task)
        pprevious, previous, score = None, None, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task, count_miss=False)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        current_time = now()
        if priority and self.expired(
                current_time - task.submission.time.timestamp(), 0):
            priority = (priority / 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)

            # On cache hits of any kind we want to send out a completed message
            self.traffic_queue.publish(
                SubmissionMessage({
                    'msg': task.submission,
                    'msg_type': 'SubmissionCompleted',
                    'sender': 'ingester',
                }).as_primitives())
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(
        self,
        task: IngestTask,
        count_miss=True
    ) -> Tuple[Optional[str], Optional[str], Optional[float], str]:
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get_if_exists(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                if count_miss:
                    self.counter.increment('cache_miss')
                return None, None, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired"
            )
            self.counter.increment('cache_expired')
            self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, None, None, key
        elif self.stale(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale"
            )
            self.counter.increment('cache_stale')
            return None, None, result.score, key

        return result.psid, result.sid, result.score, key

    def stop(self):
        super().stop()
        if self.apm_client:
            elasticapm.uninstrument()
        self.submit_client.stop()

    def stale(self, delta: float, errors: int):
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256: str = None) -> str:
        if not sha256:
            sha256 = task.submission.files[0].sha256

        key = task.submission.scan_key

        if not key:
            key = task.params.create_filescore_key(sha256)
            task.submission.scan_key = key

        return key

    def completed(self, sub: DatabaseSubmission):
        """Invoked when notified that a submission has completed."""
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.scan_key
        if not scan_key:
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission missing scan key")
            scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
        self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                               blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = [dup for dup in exhaust()]
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(
                note_queue, self.redis_persist)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng),
                                        sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256},
                                            expiry,
                                            c12n,
                                            redis=self.redis)

    def is_whitelisted(self, task: IngestTask):
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(
                safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=COMPLETE_QUEUE_NAME,
        )

        self.timeout_queue.push(int(now(_max_time)), task.submission.scan_key)
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis"
        )

    def retry(self, task: IngestTask, scan_key: str, ex):
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(
                f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})'
            )
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.as_primitives())

    def finalize(self, psid: str, sid: str, score: float, task: IngestTask):
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                f"to {'update' if task.params.psid else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis"
            )
            task.params.psid = sid
            task.submission.sid = None
            task.submission.scan_key = None
            task.params.services.resubmit = []
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score: float) -> bool:
        if not task.params.generate_alert:
            return False

        if score < self.config.core.alerter.threshold:
            return False

        return True
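

# The methods above treat a cached filescore in three tiers: fresh entries are
# reused outright, stale entries only contribute their score, and expired
# entries are deleted and rescanned. The sketch below is a minimal,
# self-contained illustration of that decision; the threshold values and the
# shortened lifetime for entries with errors are assumptions made for the
# example, not the real Assemblyline configuration values.
def classify_cache_age(age_seconds: float, errors: int,
                       stale_after: float = 4 * 60 * 60,
                       expire_after: float = 15 * 24 * 60 * 60) -> str:
    """Return 'fresh', 'stale' or 'expired' for a cached filescore entry."""
    if errors:
        # Incomplete results (those with errors) age out faster; the divisor
        # here is an illustrative assumption
        stale_after /= 4
        expire_after /= 4
    if age_seconds >= expire_after:
        return 'expired'   # delete the cache entry and rescan from scratch
    if age_seconds >= stale_after:
        return 'stale'     # keep the score as a hint, but rescan the file
    return 'fresh'         # reuse psid/sid/score directly


# An error-free entry cached five hours ago is stale but not yet expired
assert classify_cache_age(5 * 60 * 60, errors=0) == 'stale'

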
class WatcherServer(CoreBase):
    def __init__(self, redis=None, redis_persist=None):
        super().__init__('assemblyline.watcher',
                         redis=redis,
                         redis_persist=redis_persist)

        # Watcher structures
        self.hash = ExpiringHash(name=WATCHER_HASH,
                                 ttl=MAX_TIMEOUT,
                                 host=self.redis_persist)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, self.redis_persist)

        # Task management structures
        self.running_tasks = ExpiringHash(
            DISPATCH_RUNNING_TASK_HASH,
            host=self.redis)  # TODO, move to persistent?
        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE,
                                               host=self.redis_persist)

        # Metrics tracking
        self.counter = MetricsFactory(metrics_type='watcher',
                                      schema=Metrics,
                                      name='watcher',
                                      redis=self.redis,
                                      config=self.config)

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="watcher")
        else:
            self.apm_client = None

    def try_run(self):
        counter = self.counter
        apm_client = self.apm_client

        while self.running:
            self.heartbeat()

            # Download all messages from the queue that have expired
            seconds, _ = retry_call(self.redis.time)
            messages = self.queue.dequeue_range(0, seconds)

            cpu_mark = time.process_time()
            time_mark = time.time()

            # Try to pass on all the messages to their intended recipients; don't let
            # the failure of one message prevent the others from going through
            for key in messages:
                # Start of transaction
                if apm_client:
                    apm_client.begin_transaction('process_messages')

                message = self.hash.pop(key)
                if message:
                    try:
                        if message['action'] == WatcherAction.TimeoutTask:
                            self.cancel_service_task(message['task_key'],
                                                     message['worker'])
                        else:
                            queue = NamedQueue(message['queue'], self.redis)
                            queue.push(message['message'])

                        self.counter.increment('expired')
                        # End of transaction (success)
                        if apm_client:
                            apm_client.end_transaction('watch_message',
                                                       'success')
                    except Exception as error:
                        # End of transaction (exception)
                        if apm_client:
                            apm_client.end_transaction('watch_message',
                                                       'error')

                        self.log.exception(error)
                else:
                    # End of transaction (duplicate)
                    if apm_client:
                        apm_client.end_transaction('watch_message',
                                                   'duplicate')

                    self.log.warning(
                        f'Handled watch twice: {key} {len(key)} {type(key)}')

            counter.increment_execution_time('cpu_seconds',
                                             time.process_time() - cpu_mark)
            counter.increment_execution_time('busy_seconds',
                                             time.time() - time_mark)

            if not messages:
                time.sleep(0.1)

    def cancel_service_task(self, task_key, worker):
        # We believe a service task has timed out, so try to read it from the running tasks table
        # If we can't find the task in running tasks, it finished JUST before timing out; let it go
        task = self.running_tasks.pop(task_key)
        if not task:
            return

        # We can confirm that the task is ours now, even if the worker finished, the result will be ignored
        task = Task(task)
        self.log.info(
            f"[{task.sid}] Service {task.service_name} timed out on {task.fileinfo.sha256}."
        )

        # Mark the previous attempt as invalid and redispatch it
        dispatch_table = DispatchHash(task.sid, self.redis)
        dispatch_table.fail_recoverable(task.fileinfo.sha256,
                                        task.service_name)
        dispatch_table.dispatch(task.fileinfo.sha256, task.service_name)
        get_service_queue(task.service_name,
                          self.redis).push(task.priority, task.as_primitives())

        # We push the task of killing the container off to the scaler, which already has root access.
        # The scaler can also double check that the service name and container id match, to be sure
        # we aren't accidentally killing the wrong container
        self.scaler_timeout_queue.push({
            'service': task.service_name,
            'container': worker
        })

        # Report to the metrics system that a recoverable error has occurred for that service
        export_metrics_once(task.service_name,
                            ServiceMetrics,
                            dict(fail_recoverable=1),
                            host=worker,
                            counter_type='service')
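

# WatcherServer stores each watch message in an ExpiringHash keyed by a unique
# id, and schedules that id in a priority queue scored by the message's expiry
# timestamp; try_run() then pops every entry whose score is at or below the
# current time. The toy class below reproduces that pattern with plain Python
# structures (a dict and a heap) purely for illustration; it is not the
# Redis-backed implementation used above.
import heapq
import time


class ToyWatcher:
    def __init__(self):
        self._messages = {}   # key -> message payload
        self._schedule = []   # heap of (expiry_time, key)

    def watch(self, key: str, message: dict, timeout: float):
        """Store a message and schedule it to fire `timeout` seconds from now."""
        self._messages[key] = message
        heapq.heappush(self._schedule, (time.time() + timeout, key))

    def pop_expired(self):
        """Yield every message whose timeout has elapsed, oldest first."""
        now_ts = time.time()
        while self._schedule and self._schedule[0][0] <= now_ts:
            _, key = heapq.heappop(self._schedule)
            message = self._messages.pop(key, None)
            if message is not None:   # may have been cancelled already
                yield key, message

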
class Ingester:
    """Internal interface to the ingestion queues."""
    def __init__(self,
                 datastore,
                 logger,
                 classification=None,
                 redis=None,
                 persistent_redis=None,
                 metrics_name='ingester'):
        self.datastore = datastore
        self.log = logger

        # Cache the user groups
        self.cache_lock = threading.RLock()  # TODO are middle man instances single threaded now?
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Create a config cache that will refresh config values periodically
        self.config = forge.CachedObject(forge.get_config)

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded from a non-constant path, so this has to be done at init rather than at module load
        constants = forge.get_constants(self.config)
        self.priority_value = constants.PRIORITIES
        self.priority_range = constants.PRIORITY_RANGES
        self.threshold_value = constants.PRIORITY_THRESHOLDS

        # Connect to the redis servers
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.persistent_redis = persistent_redis or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester',
                                      schema=Metrics,
                                      redis=self.redis,
                                      config=self.config,
                                      name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.persistent_redis)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(_completeq_name, self.redis)

        # Internal. Dropped entries are placed on this queue.
        # self.drop_queue = NamedQueue('m-drop', self.persistent_redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME,
                                       self.persistent_redis)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.persistent_redis)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.persistent_redis)

        # Internal, timeout watch queue
        self.timeout_queue = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the completed
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestions of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.persistent_redis)

        # Output. submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore,
                                              redis=self.redis)

    def get_groups_from_user(self, username: str) -> List[str]:
        # Reset the group cache at the top of each hour
        if time.time() // HOUR_IN_SECONDS > self._user_groups_reset:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        # Get the groups for this user if not known
        if username not in self._user_groups:
            user_data = self.datastore.user.get(username)
            if user_data:
                self._user_groups[username] = user_data.groups
            else:
                self._user_groups[username] = []
        return self._user_groups[username]

    def ingest(self, task: IngestTask):
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Task received for processing"
        )
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task,
                                   failure="Invalid sha256",
                                   logfunc=self.log.warning)
            return

        # Clean up metadata strings; since we may delete some, iterate over a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] '
                    f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(
                task.params.submitter)

        # Check if this file is already being processed
        pprevious, previous, score = None, False, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        current_time = now()
        if priority and self.expired(
                current_time - task.submission.time.timestamp(), 0):
            priority = (priority / 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(self, task: IngestTask):
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                self.counter.increment('cache_miss')
                return None, False, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired"
            )
            self.counter.increment('cache_expired')
            self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, False, None, key
        elif self.stale(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale"
            )
            self.counter.increment('cache_stale')
            return None, False, result.score, key

        return result.psid, result.sid, result.score, key

    def stale(self, delta: float, errors: int):
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256=None):
        if not sha256:
            sha256 = task.submission.files[0].sha256

        key = task.scan_key

        if not key:
            key = task.params.create_filescore_key(sha256)
            task.scan_key = key

        return key

    def completed(self, sub):
        """Invoked when notified that a submission has completed."""
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
            self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                               blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = [dup for dup in exhaust()]
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(
                note_queue, self.persistent_redis)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng),
                                        sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256},
                                            expiry,
                                            c12n,
                                            redis=self.redis)

    def is_whitelisted(self, task: IngestTask):
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(
                safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=_completeq_name,
        )

        self.timeout_queue.push(int(now(_max_time)), task.scan_key)
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis"
        )

    def retry(self, task, scan_key, ex):
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(
                f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})'
            )
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.json())

    def finalize(self, psid, sid, score, task: IngestTask):
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                f"to {'update' if will_resubmit else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis"
            )
            task.params.psid = sid
            task.submission.sid = None
            task.params.services.resubmit = []
            task.scan_key = None
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score):
        if not task.params.generate_alert:
            return False

        if score < self.threshold_value['critical']:
            return False

        return True
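

# When Ingester.ingest() above receives a task without an explicit priority,
# it maps a previously cached score onto a priority level by walking a table
# of thresholds from the most severe level down. The sketch below shows that
# mapping in isolation; the level names, priority values and thresholds are
# illustrative assumptions standing in for the PRIORITIES and
# PRIORITY_THRESHOLDS constants loaded at init time.
_SKETCH_PRIORITY_VALUES = {'low': 100, 'medium': 200, 'high': 300, 'critical': 400}
_SKETCH_PRIORITY_THRESHOLDS = {'critical': 500, 'high': 100, 'medium': 0}


def assign_priority_sketch(score=None, low_priority_hint=False):
    """Pick a priority level from a cached score, defaulting to medium."""
    priority = _SKETCH_PRIORITY_VALUES['medium']
    if score is not None:
        priority = _SKETCH_PRIORITY_VALUES['low']
        # Thresholds are ordered most severe first, so the first match wins
        for level, threshold in _SKETCH_PRIORITY_THRESHOLDS.items():
            if score >= threshold:
                priority = _SKETCH_PRIORITY_VALUES[level]
                break
    elif low_priority_hint:
        priority = _SKETCH_PRIORITY_VALUES['low']
    return priority


# A previously seen score of 600 lands in the 'critical' band
assert assign_priority_sketch(score=600) == _SKETCH_PRIORITY_VALUES['critical']

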
class DispatchClient:
    def __init__(self,
                 datastore=None,
                 redis=None,
                 redis_persist=None,
                 logger=None):
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.timeout_watcher = WatcherClient(redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger(
            "assemblyline.dispatching.client")
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH,
                                               host=redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH,
                                          host=self.redis)
        self.service_data = cast(Dict[str, Service],
                                 CachedObject(self._get_services))

    def _get_services(self):
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def dispatch_submission(self,
                            submission: Submission,
                            completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(
            SubmissionTask(
                dict(
                    submission=submission,
                    completed_queue=completed_queue,
                )).as_primitives())

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of files each
        of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and number of files
                 remaining per service, e.g. {"SERVICE_NAME": 1, ... }
        """
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        all_service_status = dispatch_hash.all_results()
        all_files = dispatch_hash.all_files()
        self.log.info(
            f"[{sid}] Listing outstanding services {len(all_files)} files "
            f"and {len(all_service_status)} entries found")

        output: Dict[str, int] = {}

        for file_hash in all_files:
            # The schedule might not be in the cache if the submission or file was just issued,
            # but it will be as soon as the file passes through the dispatcher
            schedule = dispatch_hash.schedules.get(file_hash)
            status_values = all_service_status.get(file_hash, {})

            # Go through the schedule stage by stage so we can react to drops
            # either we have a result and we don't need to count the file (but might drop it)
            # or we don't have a result, and we need to count that file
            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    status = status_values.get(service_name)
                    if status:
                        if status.drop:
                            schedule.clear()
                    else:
                        output[service_name] = output.get(service_name, 0) + 1

        return output

    def request_work(self,
                     worker_id,
                     service_name,
                     service_version,
                     timeout: float = 60,
                     blocking=True) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param service_version:
        :param worker_id:
        :param service_name: Which service needs work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately
        :return: The task that was found, or None if no task was available before the timeout.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            try:
                return self._request_work(worker_id,
                                          service_name,
                                          service_version,
                                          blocking=blocking,
                                          timeout=remaining)
            except RetryRequestWork:
                remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version, timeout,
                      blocking) -> Optional[ServiceTask]:
        # For when we recursively retry on bad task dequeue-ing
        if int(timeout) <= 0:
            self.log.info(
                f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout))
        else:
            result = work_queue.pop(1)
            if result:
                result = result[0]

        if not result:
            self.log.info(
                f"{service_name}:{worker_id} no task returned: [empty message]"
            )
            return None
        task = ServiceTask(result)

        # Try to claim this task; if another worker already holds it, the add fails and we skip this copy
        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(
                f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found"
            )

            process_table = DispatchHash(task.sid, self.redis)

            abandoned = process_table.dispatch_time(
                file_hash=task.fileinfo.sha256, service=task.service_name) == 0
            finished = process_table.finished(
                file_hash=task.fileinfo.sha256,
                service=task.service_name) is not None

            # A service might be re-dispatched just as it finishes, in which case it can be marked as
            # both finished and dispatched; when that happens, drop the dispatch entry from the table
            if finished and not abandoned:
                process_table.drop_dispatch(file_hash=task.fileinfo.sha256,
                                            service=task.service_name)

            if abandoned or finished:
                self.log.info(
                    f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task already complete"
                )
                self.running_tasks.pop(task.key())
                raise RetryRequestWork()

            # Check if this task has reached the retry limit
            attempt_record = ExpiringHash(f'dispatch-hash-attempts-{task.sid}',
                                          host=self.redis)
            total_attempts = attempt_record.increment(task.key())
            self.log.info(
                f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} "
                f"task attempt {total_attempts}/3")
            if total_attempts > 3:
                self.log.warning(
                    f"[{task.sid}/{task.fileinfo.sha256}] "
                    f"{service_name}:{worker_id} marking task failed: TASK PREEMPTED "
                )
                error = Error(
                    dict(
                        archive_ts=now_as_iso(
                            self.config.datastore.ilm.days_until_archive * 24 *
                            60 * 60),
                        created='NOW',
                        expiry_ts=now_as_iso(task.ttl * 24 * 60 *
                                             60) if task.ttl else None,
                        response=dict(
                            message='The number of retries has passed the limit.',
                            service_name=task.service_name,
                            service_version=service_version,
                            status='FAIL_NONRECOVERABLE',
                        ),
                        sha256=task.fileinfo.sha256,
                        type="TASK PRE-EMPTED",
                    ))
                error_key = error.build_key(task=task)
                self.service_failed(task.sid, error_key, error)
                export_metrics_once(service_name,
                                    Metrics,
                                    dict(fail_nonrecoverable=1),
                                    host=worker_id,
                                    counter_type='service')
                raise RetryRequestWork()

            # Get the service information
            service_data = self.service_data[task.service_name]
            self.timeout_watcher.touch_task(timeout=int(service_data.timeout),
                                            key=f'{task.sid}-{task.key()}',
                                            worker=worker_id,
                                            task_key=task.key())
            return task
        raise RetryRequestWork()

    def _dispatching_error(self, task, process_table, error):
        error_key = error.build_key(task=task)
        if process_table.add_error(error_key):
            self.errors.save(error_key, error)
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)

    def service_finished(self,
                         sid: str,
                         result_key: str,
                         result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=result.response.service_name,
            sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing successful results."
            )
            return
        task = ServiceTask(task)

        # Check if the service is a candidate for dynamic recursion prevention
        if not task.ignore_dynamic_recursion_prevention:
            service_info = self.service_data.get(result.response.service_name,
                                                 None)
            if service_info and service_info.category == "Dynamic Analysis":
                # TODO: This should be done in Lua because it can introduce race conditions in the future,
                #       but in the meantime it will remain this way while we confirm it works as expected
                submission = self.active_submissions.get(sid)
                submission['submission']['params']['services'][
                    'runtime_excluded'].append(result.response.service_name)
                self.active_submissions.set(sid, submission)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key,
                                     {"expiry_ts": result.archive_ts})
        else:
            with Lock(f"lock-{result_key}", 5, self.redis):
                old = self.ds.result.get(result_key)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        result.expiry_ts = None
                self.ds.result.save(result_key, result)

        # Let the logs know we have received a result for this task
        if result.drop_file:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key} but processing will stop after this service."
            )
        else:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key}")

        # Store the result object and mark the service finished in the global table
        process_table = DispatchHash(task.sid, self.redis)
        remaining, duplicate = process_table.finish(
            task.fileinfo.sha256, task.service_name, result_key,
            result.result.score, result.classification, result.drop_file)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')
        if duplicate:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already "
                f"completed in the global processing table.")
            return

        # Push the result tags into redis
        new_tags = []
        for section in result.result.sections:
            new_tags.extend(tag_dict_to_list(section.tags.as_primitives()))
        if new_tags:
            tag_set = ExpiringSet(get_tag_set_name(
                sid=task.sid, file_hash=task.fileinfo.sha256),
                                  host=self.redis)
            tag_set.add(*new_tags)

        # Update the temporary data table for this file
        temp_data_hash = ExpiringHash(get_temporary_submission_data_name(
            sid=task.sid, file_hash=task.fileinfo.sha256),
                                      host=self.redis)
        for key, value in (temporary_data or {}).items():
            temp_data_hash.set(key, value)

        # Send the extracted files to the dispatcher
        depth_limit = self.config.submission.max_extraction_depth
        new_depth = task.depth + 1
        if new_depth < depth_limit:
            # Prepare the temporary data from the parent to build the temporary data table for
            # these newly extract files
            parent_data = dict(temp_data_hash.items())

            for extracted_data in result.response.extracted:
                if not process_table.add_file(
                        extracted_data.sha256,
                        task.max_files,
                        parent_hash=task.fileinfo.sha256):
                    if parent_data:
                        child_hash_name = get_temporary_submission_data_name(
                            task.sid, extracted_data.sha256)
                        ExpiringHash(child_hash_name,
                                     host=self.redis).multi_set(parent_data)

                    self._dispatching_error(
                        task, process_table,
                        Error({
                            'archive_ts': result.archive_ts,
                            'expiry_ts': result.expiry_ts,
                            'response': {
                                'message':
                                f"Too many files extracted for submission {task.sid} "
                                f"{extracted_data.sha256} extracted by "
                                f"{task.service_name} will be dropped",
                                'service_name':
                                task.service_name,
                                'service_tool_version':
                                result.response.service_tool_version,
                                'service_version':
                                result.response.service_version,
                                'status':
                                'FAIL_NONRECOVERABLE'
                            },
                            'sha256': extracted_data.sha256,
                            'type': 'MAX FILES REACHED'
                        }))
                    continue
                file_data = self.files.get(extracted_data.sha256)
                self.file_queue.push(
                    FileTask(
                        dict(sid=task.sid,
                             min_classification=task.min_classification.max(
                                 extracted_data.classification).value,
                             file_info=dict(
                                 magic=file_data.magic,
                                 md5=file_data.md5,
                                 mime=file_data.mime,
                                 sha1=file_data.sha1,
                                 sha256=file_data.sha256,
                                 size=file_data.size,
                                 type=file_data.type,
                             ),
                             depth=new_depth,
                             parent_hash=task.fileinfo.sha256,
                             max_files=task.max_files)).as_primitives())
        else:
            for extracted_data in result.response.extracted:
                self._dispatching_error(
                    task, process_table,
                    Error({
                        'archive_ts': result.archive_ts,
                        'expiry_ts': result.expiry_ts,
                        'response': {
                            'message':
                            f"{task.service_name} has extracted a file "
                            f"{extracted_data.sha256} beyond the depth limits",
                            'service_name':
                            result.response.service_name,
                            'service_tool_version':
                            result.response.service_tool_version,
                            'service_version':
                            result.response.service_version,
                            'status':
                            'FAIL_NONRECOVERABLE'
                        },
                        'sha256': extracted_data.sha256,
                        'type': 'MAX DEPTH REACHED'
                    }))

        # If the global table said that this was the last outstanding service,
        # send a message to the dispatchers.
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w).push(msg)

    def service_failed(self, sid: str, error_key: str, error: Error):
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=error.response.service_name,
            sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(
            f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error."
        )
        remaining = -1
        # Mark the attempt to process the file over in the dispatch table
        process_table = DispatchHash(task.sid, self.redis)
        if error.response.status == "FAIL_RECOVERABLE":
            # Because the error is recoverable, we will not save it, nor will we notify the user
            process_table.fail_recoverable(task.fileinfo.sha256,
                                           task.service_name)
        else:
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            remaining, _duplicate = process_table.fail_nonrecoverable(
                task.fileinfo.sha256, task.service_name, error_key)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')

        # Send a message to prompt the re-issue of the task if needed
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

    def setup_watch_queue(self, sid):
        """
        This function takes a submission ID as a parameter and creates a unique queue where all service
        result keys for that given submission will be returned to as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        # Create a unique queue
        queue_name = reply_queue_name(prefix="D", suffix="WQ")
        watch_queue = NamedQueue(queue_name, ttl=30)
        watch_queue.push(
            WatchQueueMessage({
                'status': 'START'
            }).as_primitives())

        # Add the newly created queue to the list of queues for the given submission
        self._get_watcher_list(sid).add(queue_name)

        # Push all current keys to the newly created queue (Queue should have a TTL of about 30 sec to 1 minute)
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        if dispatch_hash.dispatch_count() == 0 and dispatch_hash.finished_count() == 0:
            # This table is empty? do we have this submission at all?
            submission = self.ds.submission.get(sid)
            if not submission or submission.state == 'completed':
                watch_queue.push(
                    WatchQueueMessage({
                        "status": "STOP"
                    }).as_primitives())
            else:
                # We do have a submission, remind the dispatcher to work on it
                self.submission_queue.push({'sid': sid})

        else:
            all_service_status = dispatch_hash.all_results()
            for status_values in all_service_status.values():
                for status in status_values.values():
                    if status.is_error:
                        watch_queue.push(
                            WatchQueueMessage({
                                "status": "FAIL",
                                "cache_key": status.key
                            }).as_primitives())
                    else:
                        watch_queue.push(
                            WatchQueueMessage({
                                "status": "OK",
                                "cache_key": status.key
                            }).as_primitives())

        return queue_name

    def _get_watcher_list(self, sid):
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)
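
# --- Illustrative sketch (not part of the original source) ---
# A minimal example of how a client could consume the watch queue created by
# setup_watch_queue() above. `dispatcher_instance` is assumed to be an instance of the
# class above and `redis_client` the same non-persistent redis client; the message
# shapes follow the WatchQueueMessage pushes elsewhere in this file.
def watch_submission_example(dispatcher_instance, redis_client, sid):
    queue_name = dispatcher_instance.setup_watch_queue(sid)
    watch_queue = NamedQueue(queue_name, host=redis_client, ttl=30)
    while True:
        message = watch_queue.pop(timeout=5)
        if message is None or message['status'] == 'STOP':
            break  # Submission finished, or the queue expired
        if message['status'] in ('OK', 'FAIL'):
            print(message['status'], message.get('cache_key'))
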
class Dispatcher:

    def __init__(self, datastore, redis, redis_persist, logger, counter_name='dispatcher'):
        # Load the datastore collections that we are going to be using
        self.datastore: AssemblylineDatastore = datastore
        self.log: logging.Logger = logger
        self.submissions: Collection = datastore.submission
        self.results: Collection = datastore.result
        self.errors: Collection = datastore.error
        self.files: Collection = datastore.file

        # Create a config cache that will refresh config values periodically
        self.config: Config = forge.get_config()

        # Connect to all of our persistent redis structures
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        # Build some utility classes
        self.scheduler = Scheduler(datastore, self.config, self.redis)
        self.classification_engine = forge.get_classification()
        self.timeout_watcher = WatcherClient(self.redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self._nonper_other_queues = {}
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)

        # Publish counters to the metrics sink.
        self.counter = MetricsFactory(metrics_type='dispatcher', schema=Metrics, name=counter_name,
                                      redis=self.redis, config=self.config)

    def volatile_named_queue(self, name: str) -> NamedQueue:
        if name not in self._nonper_other_queues:
            self._nonper_other_queues[name] = NamedQueue(name, self.redis)
        return self._nonper_other_queues[name]

    def dispatch_submission(self, task: SubmissionTask):
        """
        Find any files associated with a submission and dispatch them if they are
        not marked as in progress. If all files are finished, finalize the submission.

        This version of dispatch submission doesn't verify each result, but assumes that
        the dispatch table has been kept up to date by other components.

        Preconditions:
            - File exists in the filestore and file collection in the datastore
            - Submission is stored in the datastore
        """
        submission = task.submission
        sid = submission.sid

        if not self.active_submissions.exists(sid):
            self.log.info(f"[{sid}] New submission received")
            self.active_submissions.add(sid, task.as_primitives())
        else:
            self.log.info(f"[{sid}] Received a pre-existing submission, check if it is complete")

        # Refresh the watch; this ensures that this function will be called again
        # if something goes wrong with one of the files and it never gets re-triggered by dispatch_file.
        self.timeout_watcher.touch(key=sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': sid})

        # Refresh the quota hold
        if submission.params.quota_item and submission.params.submitter:
            self.log.info(f"[{sid}] Submission will count towards {submission.params.submitter.upper()} quota")
            Hash('submissions-' + submission.params.submitter, self.redis_persist).add(sid, isotime.now_as_iso())

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(submission.sid, self.redis, fetch_results=True)
        file_parents = dispatch_table.file_tree()  # Load the file tree data as well

        # All the submission files, and all the file_tree files, to be sure we don't miss any incomplete children
        unchecked_hashes = [submission_file.sha256 for submission_file in submission.files]
        unchecked_hashes = list(set(unchecked_hashes) | set(file_parents.keys()))

        # Using the file tree we can recalculate the depth of any file
        depth_limit = self.config.submission.max_extraction_depth
        file_depth = depths_from_tree(file_parents)

        # Try to find all files, and extracted files, and create task objects for them
        # (we will need the file data anyway for checking the schedule later)
        max_files = len(submission.files) + submission.params.max_extracted
        unchecked_files = []  # Files that haven't been checked yet
        try:
            for sha, file_data in self.files.multiget(unchecked_hashes).items():
                unchecked_files.append(FileTask(dict(
                    sid=sid,
                    min_classification=task.submission.classification,
                    file_info=dict(
                        magic=file_data.magic,
                        md5=file_data.md5,
                        mime=file_data.mime,
                        sha1=file_data.sha1,
                        sha256=file_data.sha256,
                        size=file_data.size,
                        type=file_data.type,
                    ),
                    depth=file_depth.get(sha, 0),
                    max_files=max_files
                )))
        except MultiKeyError as missing:
            errors = []
            for file_sha in missing.keys:
                error = Error(dict(
                    archive_ts=submission.archive_ts,
                    expiry_ts=submission.expiry_ts,
                    response=dict(
                        message="Submission couldn't be completed due to missing file.",
                        service_name="dispatcher",
                        service_tool_version='4',
                        service_version='4',
                        status="FAIL_NONRECOVERABLE",
                    ),
                    sha256=file_sha,
                    type='UNKNOWN'
                ))
                error_key = error.build_key(service_tool_version=sid)
                self.datastore.error.save(error_key, error)
                errors.append(error_key)
            return self.cancel_submission(task, errors, file_parents)

        # Files that have already been encountered, but may or may not have been processed yet
        # encountered_files = {file.sha256 for file in submission.files}
        pending_files = {}  # Files that have not yet been processed

        # Track information about the results as we hit them
        file_scores: Dict[str, int] = {}

        # Load the current state of the dispatch table in one go rather than one at a time in the loop
        prior_dispatches = dispatch_table.all_dispatches()

        # Walk the schedule for every known file; any file with services still in progress or missing results becomes pending
        for file_task in unchecked_files:
            sha = file_task.file_info.sha256
            schedule = self.build_schedule(dispatch_table, submission, sha, file_task.file_info.type)

            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    # Only active services should be in this dict, so if a service that was placed in the
                    # schedule is now missing it has been disabled or taken offline.
                    service = self.scheduler.services.get(service_name)
                    if not service:
                        continue

                    # If the service is still marked as 'in progress'
                    runtime = time.time() - prior_dispatches.get(sha, {}).get(service_name, 0)
                    if runtime < service.timeout:
                        pending_files[sha] = file_task
                        continue

                    # It hasn't started, has timed out, or is finished, see if we have a result
                    result_row = dispatch_table.finished(sha, service_name)

                    # No result found, mark the file as incomplete
                    if not result_row:
                        pending_files[sha] = file_task
                        continue

                    if not submission.params.ignore_filtering and result_row.drop:
                        schedule.clear()

                    # The dispatch table shows the service was abandoned due to errors; skip it
                    if result_row.is_error:
                        continue

                    # Collect information about the result
                    file_scores[sha] = file_scores.get(sha, 0) + result_row.score

        # Using the file tree, find the shallowest parent of the given file
        def lowest_parent(_sha):
            # A root file won't have any parents in the dict
            if _sha not in file_parents or None in file_parents[_sha]:
                return None
            return min((file_depth.get(parent, depth_limit), parent) for parent in file_parents[_sha])[1]

        # Filter out things over the depth limit
        pending_files = {sha: ft for sha, ft in pending_files.items() if ft.depth < depth_limit}

        # Filter out files based on the extraction limits
        pending_files = {sha: ft for sha, ft in pending_files.items()
                         if dispatch_table.add_file(sha, max_files, lowest_parent(sha))}

        # If there are pending files, then at least one service on at least one file isn't done yet
        # and hasn't been filtered out by the previous steps, so poke those files
        if pending_files:
            self.log.debug(f"[{sid}] Dispatching {len(pending_files)} files: {list(pending_files.keys())}")
            for file_task in pending_files.values():
                self.file_queue.push(file_task.as_primitives())
        else:
            self.log.debug(f"[{sid}] Finalizing submission.")
            max_score = max(file_scores.values()) if file_scores else 0  # Submissions with no results have no score
            self.finalize_submission(task, max_score, file_scores.keys())

    def _cleanup_submission(self, task: SubmissionTask, file_list: List[str]):
        """Clean up code that is the same for canceled and finished submissions"""
        submission = task.submission
        sid = submission.sid

        # Erase the temporary data which may have accumulated during processing
        for file_hash in file_list:
            hash_name = get_temporary_submission_data_name(sid, file_hash=file_hash)
            ExpiringHash(hash_name, host=self.redis).delete()

        if submission.params.quota_item and submission.params.submitter:
            self.log.info(f"[{sid}] Submission no longer counts toward {submission.params.submitter.upper()} quota")
            Hash('submissions-' + submission.params.submitter, self.redis_persist).pop(sid)

        if task.completed_queue:
            self.volatile_named_queue(task.completed_queue).push(submission.as_primitives())

        # Send complete message to any watchers.
        watcher_list = ExpiringSet(make_watcher_list_name(sid), host=self.redis)
        for w in watcher_list.members():
            NamedQueue(w).push(WatchQueueMessage({'status': 'STOP'}).as_primitives())

        # Clear the timeout watcher
        watcher_list.delete()
        self.timeout_watcher.clear(sid)
        self.active_submissions.pop(sid)

        # Count the submission as 'complete' either way
        self.counter.increment('submissions_completed')

    def cancel_submission(self, task: SubmissionTask, errors, file_list):
        """The submission is being abandoned, delete everything, write failed state."""
        submission = task.submission
        sid = submission.sid

        # Pull down the dispatch table and clear it from redis
        dispatch_table = DispatchHash(submission.sid, self.redis)
        dispatch_table.delete()

        submission.classification = submission.params.classification
        submission.error_count = len(errors)
        submission.errors = errors
        submission.state = 'failed'
        submission.times.completed = isotime.now_as_iso()
        self.submissions.save(sid, submission)

        self._cleanup_submission(task, file_list)
        self.log.error(f"[{sid}] Failed")

    def finalize_submission(self, task: SubmissionTask, max_score, file_list):
        """All of the services for all of the files in this submission have finished or failed.

        Update the records in the datastore, and flush the working data from redis.
        """
        submission = task.submission
        sid = submission.sid

        # Pull down the dispatch table and clear it from redis
        dispatch_table = DispatchHash(submission.sid, self.redis)
        all_results = dispatch_table.all_results()
        errors = dispatch_table.all_extra_errors()
        dispatch_table.delete()

        # Sort the errors out of the results
        results = []
        for row in all_results.values():
            for status in row.values():
                if status.is_error:
                    errors.append(status.key)
                elif status.bucket == 'result':
                    results.append(status.key)
                else:
                    self.log.warning(f"[{sid}] Unexpected service output bucket: {status.bucket}/{status.key}")

        submission.classification = submission.params.classification
        submission.error_count = len(errors)
        submission.errors = errors
        submission.file_count = len(file_list)
        submission.results = results
        submission.max_score = max_score
        submission.state = 'completed'
        submission.times.completed = isotime.now_as_iso()
        self.submissions.save(sid, submission)

        self._cleanup_submission(task, file_list)
        self.log.info(f"[{sid}] Completed; files: {len(file_list)} results: {len(results)} "
                      f"errors: {len(errors)} score: {max_score}")

    def dispatch_file(self, task: FileTask):
        """ Handle a message describing a file to be processed.

        This file may be:
            - A new submission or extracted file.
            - A file that has just completed a stage of processing.
            - A file that has not completed a stage of processing, but this
              call has been triggered by a timeout or similar.

        If the file is totally new, we will set up a dispatch table and fill it in.

        Once we make/load a dispatch table, we will dispatch whichever group the table
        shows us hasn't been completed yet.

        When we dispatch to a service, we check if the task is already in the dispatch
        queue. If it isn't, proceed normally. If it is, check that the service is still online.
        """
        # Read the message content
        file_hash = task.file_info.sha256
        active_task = self.active_submissions.get(task.sid)

        if active_task is None:
            self.log.warning(f"[{task.sid}] Untracked submission is being processed")
            return

        submission_task = SubmissionTask(active_task)
        submission = submission_task.submission

        # Refresh the watch on the submission; we are still working on it
        self.timeout_watcher.touch(key=task.sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': task.sid})

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(task.sid, self.redis, fetch_results=True)

        # Load the tags and temporary submission data needed to fill out the service task messages
        file_tags = ExpiringSet(task.get_tag_set_name(), host=self.redis)
        file_tags_data = file_tags.members()
        temporary_submission_data = ExpiringHash(task.get_temporary_submission_data_name(), host=self.redis)
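        # Note: items() on the ExpiringHash returns a mapping of stored values here, so the
        # second .items() below yields (name, value) pairs for the service task message.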
        temporary_data = [dict(name=row[0], value=row[1]) for row in temporary_submission_data.items().items()]

        # Calculate the schedule for the file
        schedule = self.build_schedule(dispatch_table, submission, file_hash, task.file_info.type)
        started_stages = []

        # Go through each round of the schedule removing complete/failed services
        # Break when we find a stage that still needs processing
        outstanding = {}
        score = 0
        errors = 0
        while schedule and not outstanding:
            stage = schedule.pop(0)
            started_stages.append(stage)

            for service_name in stage:
                service = self.scheduler.services.get(service_name)
                if not service:
                    continue

                # Load the results; if there are no results, the service must be dispatched later.
                # Don't check whether it has already been dispatched: duplicate dispatches are fine,
                # but missing a dispatch isn't.
                finished = dispatch_table.finished(file_hash, service_name)
                if not finished:
                    outstanding[service_name] = service
                    continue

                # If the service terminated in an error, count the error and continue
                if finished.is_error:
                    errors += 1
                    continue

                # If the service finished, count the score and check whether the file has been dropped
                score += finished.score
                if not submission.params.ignore_filtering and finished.drop:
                    schedule.clear()
                    if schedule:  # If there are still stages in the schedule, overwrite them for next time
                        dispatch_table.schedules.set(file_hash, started_stages)

        # Try to retry/dispatch any outstanding services
        if outstanding:
            self.log.info(f"[{task.sid}] File {file_hash} sent to services : {', '.join(list(outstanding.keys()))}")

            for service_name, service in outstanding.items():

                # Find the actual file name from the list of files in submission
                filename = None
                for file in submission.files:
                    if task.file_info.sha256 == file.sha256:
                        filename = file.name
                        break

                # Build the actual service dispatch message
                config = self.build_service_config(service, submission)
                service_task = ServiceTask(dict(
                    sid=task.sid,
                    metadata=submission.metadata,
                    min_classification=task.min_classification,
                    service_name=service_name,
                    service_config=config,
                    fileinfo=task.file_info,
                    filename=filename or task.file_info.sha256,
                    depth=task.depth,
                    max_files=task.max_files,
                    ttl=submission.params.ttl,
                    ignore_cache=submission.params.ignore_cache,
                    ignore_dynamic_recursion_prevention=submission.params.ignore_dynamic_recursion_prevention,
                    tags=file_tags_data,
                    temporary_submission_data=temporary_data,
                    deep_scan=submission.params.deep_scan,
                    priority=submission.params.priority,
                ))
                dispatch_table.dispatch(file_hash, service_name)
                queue = get_service_queue(service_name, self.redis)
                queue.push(service_task.priority, service_task.as_primitives())

        else:
            # There are no outstanding services, this file is done
            # clean up the tags
            file_tags.delete()

            # If there is nothing outstanding at all for this submission,
            # send a message to the submission dispatcher to finalize it
            self.counter.increment('files_completed')
            if dispatch_table.all_finished():
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}' starting submission finalization.")
                self.submission_queue.push({'sid': submission.sid})
            else:
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}'. Other files are not finished.")

    def build_schedule(self, dispatch_hash: DispatchHash, submission: Submission,
                       file_hash: str, file_type: str) -> List[List[str]]:
        """Rather than rebuilding the schedule every time we see a file, build it once and cache in redis."""
        cached_schedule = dispatch_hash.schedules.get(file_hash)
        if not cached_schedule:
            # Get the schedule for that file type based on the submission parameters
            obj_schedule = self.scheduler.build_schedule(submission, file_type)
            # The schedule built by the scheduler contains the service objects; we only need the names for now
            cached_schedule = [list(stage.keys()) for stage in obj_schedule]
            dispatch_hash.schedules.add(file_hash, cached_schedule)
        return cached_schedule

    @classmethod
    def build_service_config(cls, service: Service, submission: Submission) -> Dict[str, str]:
        """Prepare the service config that will be used downstream.

        v3 names: get_service_params get_config_data
        """
        # Load the default service config
        params = {x.name: x.default for x in service.submission_params}

        # Overwrite it with values from the submission
        if service.name in submission.params.service_spec:
            params.update(submission.params.service_spec[service.name])
        return params
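
# --- Illustrative sketch (not part of the original source) ---
# build_service_config() above boils down to a plain dict merge: start from each service
# parameter's default value, then overlay whatever the submitter supplied in
# submission.params.service_spec for that service. The same pattern with plain dicts:
def merge_service_params_example(defaults, overrides):
    params = dict(defaults)   # e.g. {'deep_scan': False, 'max_size': 100}
    params.update(overrides)  # e.g. {'deep_scan': True}
    return params             # -> {'deep_scan': True, 'max_size': 100}
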
# Example #17
class Alerter(ServerBase):
    def __init__(self):
        super().__init__('assemblyline.alerter')
        # Publish counters to the metrics sink.
        self.counter = MetricsFactory('alerter', Metrics)
        self.datastore = forge.get_datastore(self.config)
        self.persistent_redis = get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )
        self.process_alert_message = forge.get_process_alert_message()
        self.running = False

        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)
        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="alerter")
        else:
            self.apm_client = None

    def close(self):
        if self.counter:
            self.counter.stop()

        if self.apm_client:
            elasticapm.uninstrument()

    def run_once(self):
        alert = self.alert_queue.pop(timeout=1)
        if not alert:
            return

        # Start of process alert transaction
        if self.apm_client:
            self.apm_client.begin_transaction('Process alert message')

        self.counter.increment('received')
        try:
            alert_type = self.process_alert_message(self.counter,
                                                    self.datastore, self.log,
                                                    alert)

            # End of process alert transaction (success)
            if self.apm_client:
                self.apm_client.end_transaction(alert_type, 'success')

            return alert_type
        except Exception as ex:  # pylint: disable=W0703
            retries = alert['alert_retries'] = alert.get('alert_retries', 0) + 1
            if retries > MAX_RETRIES:
                self.log.exception(f'Max retries exceeded for: {alert}')
            else:
                self.alert_queue.push(alert)
                if 'Submission not finalized' not in str(ex):
                    self.log.exception(
                        f'Unhandled exception processing: {alert}')

            # End of process alert transaction (failure)
            if self.apm_client:
                self.apm_client.end_transaction('unknown', 'exception')

    def try_run(self):
        while self.running:
            self.heartbeat()
            self.run_once()
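
# --- Illustrative sketch (not part of the original source) ---
# The requeue-with-bounded-retries pattern used by Alerter.run_once() above, reduced to
# its essentials: pop a message, try to process it, and on failure push it back with a
# retry counter so a poison message cannot loop forever. `queue` is any NamedQueue-like
# object and `process` any callable; the retry limit is an arbitrary example value, not
# the real MAX_RETRIES.
def process_with_retries_example(queue, process, max_retries=6):
    message = queue.pop(timeout=1)
    if not message:
        return None
    try:
        return process(message)
    except Exception:
        message['retries'] = message.get('retries', 0) + 1
        if message['retries'] <= max_retries:
            queue.push(message)  # Give the message another chance later
        return None
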
# Example #18
class DistributedBackup(object):
    def __init__(self,
                 working_dir,
                 worker_count=50,
                 spawn_workers=True,
                 use_threading=False,
                 logger=None):
        self.working_dir = working_dir
        self.datastore = forge.get_datastore(archive_access=True)
        self.logger = logger
        self.plist = []
        self.use_threading = use_threading
        self.instance_id = get_random_id()
        self.worker_queue = NamedQueue(f"r-worker-{self.instance_id}",
                                       ttl=1800)
        self.done_queue = NamedQueue(f"r-done-{self.instance_id}", ttl=1800)
        self.hash_queue = Hash(f"r-hash-{self.instance_id}")
        self.bucket_error = []
        self.VALID_BUCKETS = sorted(self.datastore.ds.get_models().keys())
        self.worker_count = worker_count
        self.spawn_workers = spawn_workers
        self.total_count = 0
        self.error_map_count = {}
        self.missing_map_count = {}
        self.map_count = {}
        self.last_time = 0
        self.last_count = 0
        self.error_count = 0

    def cleanup(self):
        self.worker_queue.delete()
        self.done_queue.delete()
        self.hash_queue.delete()
        for p in self.plist:
            p.terminate()

    def done_thread(self, title):
        t0 = time.time()
        self.last_time = t0

        running_threads = self.worker_count

        while running_threads > 0:
            msg = self.done_queue.pop(timeout=1)

            if msg is None:
                continue

            if "stopped" in msg:
                running_threads -= 1
                continue

            bucket_name = msg.get('bucket_name', 'unknown')

            if msg.get('success', False):
                self.total_count += 1

                if msg.get("missing", False):
                    if bucket_name not in self.missing_map_count:
                        self.missing_map_count[bucket_name] = 0

                    self.missing_map_count[bucket_name] += 1
                else:
                    if bucket_name not in self.map_count:
                        self.map_count[bucket_name] = 0

                    self.map_count[bucket_name] += 1

                new_t = time.time()
                if (new_t - self.last_time) > 5:
                    if self.logger:
                        self.logger.info(
                            "%s (%s at %s keys/sec) ==> %s" %
                            (self.total_count, new_t - self.last_time,
                             int((self.total_count - self.last_count) /
                                 (new_t - self.last_time)), self.map_count))
                    self.last_count = self.total_count
                    self.last_time = new_t
            else:
                self.error_count += 1

                if bucket_name not in self.error_map_count:
                    self.error_map_count[bucket_name] = 0

                self.error_map_count[bucket_name] += 1

        # Cleanup
        self.cleanup()

        summary = ""
        summary += "\n########################\n"
        summary += "####### SUMMARY  #######\n"
        summary += "########################\n"
        summary += "%s items - %s errors - %s secs\n\n" % \
                   (self.total_count, self.error_count, time.time() - t0)

        for k, v in self.map_count.items():
            summary += "\t%15s: %s\n" % (k.upper(), v)

        if len(self.missing_map_count.keys()) > 0:
            summary += "\n\nMissing data:\n\n"
            for k, v in self.missing_map_count.items():
                summary += "\t%15s: %s\n" % (k.upper(), v)

        if len(self.error_map_count.keys()) > 0:
            summary += "\n\nErrors:\n\n"
            for k, v in self.error_map_count.items():
                summary += "\t%15s: %s\n" % (k.upper(), v)

        if len(self.bucket_error) > 0:
            summary += f"\nThese buckets failed to {title.lower()} completely: {self.bucket_error}\n"
        if self.logger:
            self.logger.info(summary)

    # noinspection PyBroadException,PyProtectedMember
    def backup(self, bucket_list, follow_keys=False, query=None):
        if query is None:
            query = 'id:*'

        for bucket in bucket_list:
            if bucket not in self.VALID_BUCKETS:
                if self.logger:
                    self.logger.warning(
                        "\n%s is not a valid bucket.\n\n"
                        "The list of valid buckets is the following:\n\n\t%s\n"
                        % (bucket.upper(), "\n\t".join(self.VALID_BUCKETS)))
                return

        targets = ', '.join(bucket_list)
        try:
            if self.logger:
                self.logger.info("\n-----------------------")
                self.logger.info("----- Data Backup -----")
                self.logger.info("-----------------------")
                self.logger.info(f"    Deep: {follow_keys}")
                self.logger.info(f"    Buckets: {targets}")
                self.logger.info(f"    Workers: {self.worker_count}")
                self.logger.info(f"    Target directory: {self.working_dir}")
                self.logger.info(f"    Filtering query: {query}")

            # Start the workers
            for x in range(self.worker_count):
                if self.use_threading:
                    t = threading.Thread(target=backup_worker,
                                         args=(x, self.instance_id,
                                               self.working_dir))
                    t.daemon = True
                    t.start()
                else:
                    p = Process(target=backup_worker,
                                args=(x, self.instance_id, self.working_dir))
                    p.start()
                    self.plist.append(p)

            # Start done thread
            dt = threading.Thread(target=self.done_thread,
                                  args=('Backup', ),
                                  name="Done thread")
            dt.daemon = True
            dt.start()

            # Process data buckets
            for bucket_name in bucket_list:
                try:
                    collection = self.datastore.get_collection(bucket_name)
                    for item in collection.stream_search(query,
                                                         fl="id",
                                                         item_buffer_size=500,
                                                         as_obj=False):
                        self.worker_queue.push({
                            "bucket_name": bucket_name,
                            "key": item['id'],
                            "follow_keys": follow_keys
                        })

                except Exception as e:
                    self.cleanup()
                    if self.logger:
                        self.logger.exception(e)
                        self.logger.error(
                            "Error occurred while processing bucket %s." %
                            bucket_name)
                    self.bucket_error.append(bucket_name)

            for _ in range(self.worker_count):
                self.worker_queue.push({"stop": True})

            dt.join()
        except Exception as e:
            if self.logger:
                self.logger.exception(e)

    def restore(self):
        try:
            if self.logger:
                self.logger.info("\n------------------------")
                self.logger.info("----- Data Restore -----")
                self.logger.info("------------------------")
                self.logger.info(f"    Workers: {self.worker_count}")
                self.logger.info(f"    Target directory: {self.working_dir}")

            for x in range(self.worker_count):
                if self.use_threading:
                    t = threading.Thread(target=restore_worker,
                                         args=(x, self.instance_id,
                                               self.working_dir))
                    t.daemon = True
                    t.start()
                else:
                    p = Process(target=restore_worker,
                                args=(x, self.instance_id, self.working_dir))
                    p.start()
                    self.plist.append(p)

            # Start done thread
            dt = threading.Thread(target=self.done_thread,
                                  args=('Restore', ),
                                  name="Done thread")
            dt.daemon = True
            dt.start()

            # Wait for workers to finish
            dt.join()
        except Exception as e:
            if self.logger:
                self.logger.exception(e)
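
# --- Illustrative usage sketch (not part of the original source) ---
# One way the DistributedBackup class above might be driven. The bucket names must be
# keys of datastore.ds.get_models(); 'submission' is only an assumed example here.
def run_backup_example(working_dir):
    backup_manager = DistributedBackup(working_dir, worker_count=10, use_threading=True)
    # backup() spawns the workers, streams matching keys onto the worker queue and
    # blocks until the done thread has joined and cleaned up.
    backup_manager.backup(['submission'], follow_keys=True)
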
# Example #19
class DirectClient(ClientBase):
    def __init__(self, log, alert_fqs=None, submission_fqs=None, lookback_time='*'):
        # Setup datastore
        config = forge.get_config()
        redis = get_client(config.core.redis.nonpersistent.host, config.core.redis.nonpersistent.port, False)

        self.datastore = forge.get_datastore(config=config)
        self.alert_queue = NamedQueue("replay_alert", host=redis)
        self.file_queue = NamedQueue("replay_file", host=redis)
        self.submission_queue = NamedQueue("replay_submission", host=redis)

        super().__init__(log, alert_fqs=alert_fqs, submission_fqs=submission_fqs, lookback_time=lookback_time)

    def _get_next_alert_ids(self, query, filter_queries):
        return self.datastore.alert.search(
            query, fl="alert_id,reporting_ts", sort="reporting_ts asc", rows=100, as_obj=False, filters=filter_queries)

    def _get_next_submission_ids(self, query, filter_queries):
        return self.datastore.submission.search(
            query, fl="sid,times.completed", sort="times.completed asc", rows=100, as_obj=False, filters=filter_queries)

    def _set_bulk_alert_pending(self, query, filter_queries, max_docs):
        operations = [(self.datastore.alert.UPDATE_SET, 'metadata.replay', REPLAY_PENDING)]
        self.datastore.alert.update_by_query(query, operations, filters=filter_queries, max_docs=max_docs)

    def _set_bulk_submission_pending(self, query, filter_queries, max_docs):
        operations = [(self.datastore.submission.UPDATE_SET, 'metadata.replay', REPLAY_PENDING)]
        self.datastore.submission.update_by_query(query, operations, filters=filter_queries, max_docs=max_docs)

    def _stream_alert_ids(self, query):
        return self.datastore.alert.stream_search(query, fl="alert_id,reporting_ts", as_obj=False)

    def _stream_submission_ids(self, query):
        return self.datastore.submission.stream_search(query, fl="sid,times.completed", as_obj=False)

    def create_alert_bundle(self, alert_id, bundle_path):
        temp_bundle_file = create_bundle(alert_id, working_dir=os.path.dirname(bundle_path), use_alert=True)
        os.rename(temp_bundle_file, bundle_path)

    def create_submission_bundle(self, sid, bundle_path):
        temp_bundle_file = create_bundle(sid, working_dir=os.path.dirname(bundle_path))
        os.rename(temp_bundle_file, bundle_path)

    def load_bundle(self, bundle_path, min_classification, rescan_services, exist_ok=True):
        import_bundle(bundle_path,
                      min_classification=min_classification,
                      rescan_services=rescan_services,
                      exist_ok=exist_ok)

    def set_single_alert_complete(self, alert_id):
        operations = [(self.datastore.alert.UPDATE_SET, 'metadata.replay', REPLAY_DONE)]
        self.datastore.alert.update(alert_id, operations)

    def set_single_submission_complete(self, sid):
        operations = [(self.datastore.submission.UPDATE_SET, 'metadata.replay', REPLAY_DONE)]
        self.datastore.submission.update(sid, operations)

    def get_next_alert(self):
        return self.alert_queue.pop(blocking=True, timeout=30)

    def get_next_file(self):
        return self.file_queue.pop(blocking=True, timeout=30)

    def get_next_submission(self):
        return self.submission_queue.pop(blocking=True, timeout=30)

    def put_alert(self, alert):
        self.alert_queue.push(alert)

    def put_file(self, path):
        self.file_queue.push(path)

    def put_submission(self, submission):
        self.submission_queue.push(submission)
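
# --- Illustrative sketch (not part of the original source) ---
# The put_*/get_next_* methods above are thin wrappers around a NamedQueue round trip.
# The same pattern on a standalone queue (the queue name here is arbitrary and not one
# used by replay):
def queue_round_trip_example(redis_client):
    queue = NamedQueue("replay_example", host=redis_client)
    queue.push({'alert_id': 'example'})          # what put_alert() does
    return queue.pop(blocking=True, timeout=30)  # what get_next_alert() does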