Esempio n. 1
0
    def __init__(self,
                 datastore=None,
                 redis=None,
                 redis_persist=None,
                 logger=None):
        """Set up the redis and datastore handles used to talk to the dispatcher.

        :param datastore: Datastore handle; obtained from forge when omitted.
        :param redis: Non-persistent redis client; built from the config when omitted.
        :param redis_persist: Persistent redis client; built from the config when omitted.
        :param logger: Logger to use; defaults to "assemblyline.dispatching.client".
        """
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.timeout_watcher = WatcherClient(redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger(
            "assemblyline.dispatching.client")
        # Read the collections from the resolved handle: the raw ``datastore``
        # argument is None whenever the caller relies on the default.
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH,
                                               host=redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH,
                                          host=self.redis)
        self.service_data = cast(Dict[str, Service],
                                 CachedObject(self._get_services))
Esempio n. 2
0
    def __init__(self, redis=None, redis_persist=None):
        """Create the redis structures, metrics counter and optional APM client.

        :param redis: Forwarded to the base class as its ``redis`` argument.
        :param redis_persist: Forwarded to the base class as its ``redis_persist`` argument.
        """
        super().__init__('assemblyline.watcher', redis=redis, redis_persist=redis_persist)

        persistent = self.redis_persist
        volatile = self.redis

        # Structures holding the watcher entries themselves
        self.hash = ExpiringHash(WATCHER_HASH, ttl=MAX_TIMEOUT, host=persistent)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, persistent)

        # Structures used to manage tasks
        # TODO, move to persistant?
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=volatile)
        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=persistent)

        # Counter used to publish metrics
        self.counter = MetricsFactory(
            metrics_type='watcher',
            schema=Metrics,
            name='watcher',
            redis=volatile,
            config=self.config,
        )

        # The APM client only exists when a server url is configured
        apm_url = self.config.core.metrics.apm_server.server_url
        if apm_url is None:
            self.apm_client = None
        else:
            self.log.info(f"Exporting application metrics to: {apm_url}")
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(server_url=apm_url, service_name="watcher")
Esempio n. 3
0
 def __init__(self,
              datastore: AssemblylineDatastore = None,
              filestore: FileStore = None,
              config=None,
              redis=None,
              redis_persist=None,
              identify=None):
     """Wire up the stores, redis structures and helpers used by the tasking client.

     :param datastore: Datastore handle; obtained from forge when omitted.
     :param filestore: Filestore handle; obtained from forge when omitted.
     :param config: Configuration object; a cached forge config is used when omitted.
     :param redis: Redis client handed to the dispatch client, event sender and status table.
     :param redis_persist: Persistent redis client handed to the dispatch client.
     :param identify: Identify helper; when omitted one is created here and
                      ``self.cleanup`` is set to True.
     """
     self.log = logging.getLogger('assemblyline.tasking_client')
     self.config = config or forge.CachedObject(forge.get_config)
     self.datastore = datastore or forge.get_datastore(self.config)
     self.dispatch_client = DispatchClient(self.datastore, redis=redis, redis_persist=redis_persist)
     self.event_sender = EventSender('changes.services', redis)
     self.filestore = filestore or forge.get_filestore(self.config)
     self.heuristic_handler = HeuristicHandler(self.datastore)

     # Index every known heuristic by its ID
     self.heuristics = {}
     for heuristic in self.datastore.list_all_heuristics():
         self.heuristics[heuristic.heur_id] = heuristic

     self.status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60 * 30, host=redis)

     # NOTE(review): the raw ``config`` argument (possibly None) is passed here,
     # not ``self.config`` - confirm this is intended.
     safelister_kwargs = {
         'log': self.log,
         'config': config,
         'datastore': self.datastore,
     }
     self.tag_safelister = forge.CachedObject(forge.get_tag_safelister,
                                              kwargs=safelister_kwargs,
                                              refresh=300)

     # cleanup is True only when no identify helper was supplied by the caller
     self.cleanup = not identify
     self.identify = identify or forge.get_identify(
         config=self.config, datastore=self.datastore, use_cache=True)
Esempio n. 4
0
    def __init__(self, redis_persist=None):
        """Connect to the persistent redis and bind the watcher hash and queue.

        :param redis_persist: Persistent redis client; when omitted, one is
                              created from the configured persistent host/port.
        """
        config = forge.get_config()

        self.redis = redis_persist or get_client(
            host=config.core.redis.persistent.host,
            port=config.core.redis.persistent.port,
            private=False,
        )
        # Use the resolved client: the raw ``redis_persist`` argument is None
        # when the caller relies on the configured default.
        self.hash = ExpiringHash(name=WATCHER_HASH,
                                 ttl=MAX_TIMEOUT,
                                 host=self.redis)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, self.redis)
def get_working_and_idle(redis, current_service):
    """Split the hosts registered for a service into running and not-running.

    Entries are read from the service state hash. An entry is considered only
    when it belongs to ``current_service`` and its time limit is still in the
    future.

    :param redis: Redis client hosting the service state hash.
    :param current_service: Name of the service to report on.
    :return: Tuple ``(busy, idle)`` of host lists.
    """
    state_table = ExpiringHash(SERVICE_STATE_HASH, host=redis, ttl=30 * 60)
    entries = state_table.items()

    busy, idle = [], []
    for host, (service, state, time_limit) in entries.items():
        # Skip other services and entries whose time limit has passed
        if service == current_service and time.time() < time_limit:
            bucket = busy if state == ServiceStatus.Running else idle
            bucket.append(host)
    return busy, idle
    def _cleanup_submission(self, task: SubmissionTask, file_list: List[str]):
        """Tear-down steps shared by cancelled and finished submissions.

        :param task: Submission task being closed out.
        :param file_list: File hashes whose temporary submission data is deleted.
        """
        submission = task.submission
        sid = submission.sid
        params = submission.params

        # Drop the temporary data gathered for each file while processing
        for file_hash in file_list:
            temp_name = get_temporary_submission_data_name(sid, file_hash=file_hash)
            ExpiringHash(temp_name, host=self.redis).delete()

        # Remove the submission from the submitter's quota hash
        if params.quota_item and params.submitter:
            self.log.info(f"[{sid}] Submission no longer counts toward {params.submitter.upper()} quota")
            quota_hash = Hash('submissions-' + params.submitter, self.redis_persist)
            quota_hash.pop(sid)

        # Publish the submission on the completion queue, if one was requested
        if task.completed_queue:
            completion_queue = self.volatile_named_queue(task.completed_queue)
            completion_queue.push(submission.as_primitives())

        # Tell every registered watcher queue to stop
        watchers = ExpiringSet(make_watcher_list_name(sid), host=self.redis)
        for watcher_queue_name in watchers.members():
            stop_message = WatchQueueMessage({'status': 'STOP'}).as_primitives()
            NamedQueue(watcher_queue_name).push(stop_message)

        # Remove the watcher list, the timeout watcher and the active entry
        watchers.delete()
        self.timeout_watcher.clear(sid)
        self.active_submissions.pop(sid)

        # Counted as completed in both the cancelled and the finished case
        self.counter.increment('submissions_completed')
Esempio n. 7
0
def test_expiring_hash(redis_connection):
    """Check that an entry in an ExpiringHash is gone once its ttl has elapsed."""
    # Nothing to verify without a redis connection
    if not redis_connection:
        return

    from assemblyline.remote.datatypes.hash import ExpiringHash
    with ExpiringHash('test-expiring-hashmap', ttl=1) as expiring:
        assert expiring.add("key", "value") == 1
        assert expiring.length() == 1
        # Wait slightly longer than the one second ttl
        time.sleep(1.1)
        assert expiring.length() == 0
Esempio n. 8
0
    def __init__(self, datastore=None, filestore=None):
        """Prepare stores, per-service counters and queues, and the dispatch client.

        :param datastore: Datastore handle; obtained from forge when omitted.
        :param filestore: Filestore handle; obtained from forge when omitted.
        """
        super().__init__('assemblyline.randomservice')
        self.config = forge.get_config()
        self.datastore = datastore or forge.get_datastore()
        self.filestore = filestore or forge.get_filestore()
        self.client_id = get_random_id()
        self.service_state_hash = ExpiringHash(SERVICE_STATE_HASH, ttl=30 * 60)

        # One metrics counter per service name
        self.counters = {}
        for service_name in self.datastore.service_delta.keys():
            self.counters[service_name] = MetricsFactory('service', Metrics,
                                                         name=service_name,
                                                         config=self.config)

        # One service queue per service name
        self.queues = []
        for service_name in self.datastore.service_delta.keys():
            self.queues.append(forge.get_service_queue(service_name))

        self.dispatch_client = DispatchClient(self.datastore)
        self.service_info = CachedObject(self.datastore.list_all_services, kwargs={'as_obj': False})
    def __init__(self, datastore, redis, redis_persist, logger, counter_name='dispatcher'):
        """Bind the datastore collections, redis structures and helpers of the dispatcher.

        :param datastore: Datastore handle whose collections are used.
        :param redis: Non-persistent redis client; built from the config when falsy.
        :param redis_persist: Persistent redis client; built from the config when falsy.
        :param logger: Logger stored as ``self.log``.
        :param counter_name: Name given to the metrics counter.
        """
        # Datastore handle and the collections used from it
        self.datastore: AssemblylineDatastore = datastore
        self.log: logging.Logger = logger
        self.submissions: Collection = datastore.submission
        self.results: Collection = datastore.result
        self.errors: Collection = datastore.error
        self.files: Collection = datastore.file

        # System configuration
        self.config: Config = forge.get_config()
        redis_config = self.config.core.redis

        # Redis connections: use the supplied clients, else build them from the config
        self.redis = redis or get_client(host=redis_config.nonpersistent.host,
                                         port=redis_config.nonpersistent.port,
                                         private=False)
        self.redis_persist = redis_persist or get_client(host=redis_config.persistent.host,
                                                         port=redis_config.persistent.port,
                                                         private=False)

        # Helper objects
        self.scheduler = Scheduler(datastore, self.config, self.redis)
        self.classification_engine = forge.get_classification()
        self.timeout_watcher = WatcherClient(self.redis_persist)

        # Queues and hashes shared through redis
        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self._nonper_other_queues = {}
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)

        # Metrics counter for the dispatcher
        self.counter = MetricsFactory(
            metrics_type='dispatcher',
            schema=Metrics,
            name=counter_name,
            redis=self.redis,
            config=self.config,
        )
    def __init__(self,
                 sid: str,
                 client: Union[Redis, StrictRedis],
                 fetch_results=False):
        """Bind the redis-backed bookkeeping structures of a single submission.

        :param sid: Submission ID from which every redis key name is derived.
        :param client: Redis connection used by all the structures.
        :param fetch_results: When true, fill the local result cache through all_results().
        """
        self.client = client
        self.sid = sid
        self._dispatch_key = f'{sid}{dispatch_tail}'
        self._finish_key = f'{sid}{finished_tail}'
        self._finish = client.register_script(finish_script)

        # Schedules computed for the dispatcher are cached here so they are not
        # rebuilt repeatedly, and so the UI can be told which services are pending
        self.schedules = ExpiringHash(f'dispatch-hash-schedules-{sid}', host=client)

        # Per-file count of the services that are still outstanding
        self._outstanding_service_count = ExpiringHash(f'dispatch-hash-files-{sid}', host=client)
        # Which file was extracted by what, kept to rebuild the file tree later
        self._file_tree = ExpiringSet(f'dispatch-hash-parents-{sid}', host=client)
        self._attempts = ExpiringHash(f'dispatch-hash-attempts-{sid}', host=client)

        # Local caches of the file set and of the finished table
        self._cached_files = set(self._outstanding_service_count.keys())
        self._cached_results = {}
        if fetch_results:
            self._cached_results = self.all_results()

        # Submission-level errors that are not the terminal error of a service
        self._other_errors = ExpiringSet(f'dispatch-hash-errors-{sid}', host=client)

        # TODO set these expire times from the global time limit for submissions
        for redis_key in (self._dispatch_key, self._finish_key):
            retry_call(client.expire, redis_key, 60 * 60)
Esempio n. 11
0
class WatcherClient:
    """Client used to register and clear watcher timeouts stored in redis."""

    def __init__(self, redis_persist=None):
        """Connect to the persistent redis and bind the watcher hash and queue.

        :param redis_persist: Persistent redis client; when omitted, one is
                              created from the configured persistent host/port.
        """
        config = forge.get_config()

        self.redis = redis_persist or get_client(
            host=config.core.redis.persistent.host,
            port=config.core.redis.persistent.port,
            private=False,
        )
        # Use the resolved client: the raw ``redis_persist`` argument is None
        # when the caller relies on the configured default.
        self.hash = ExpiringHash(name=WATCHER_HASH,
                                 ttl=MAX_TIMEOUT,
                                 host=self.redis)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, self.redis)

    def _schedule(self, timeout: int, key: str, entry: dict):
        """Store ``entry`` under ``key`` and queue the key at redis-time + timeout."""
        if timeout >= MAX_TIMEOUT:
            raise ValueError(f"Can't set watcher timeouts over {MAX_TIMEOUT}")
        self.hash.set(key, entry)
        # The redis server clock is used to compute the queue priority
        seconds, _ = retry_call(self.redis.time)
        self.queue.push(int(seconds + timeout), key)

    def touch(self, timeout: int, key: str, queue: str, message: dict):
        """Register (or refresh) a watcher entry with the ``Message`` action.

        :param timeout: Seconds from now at which the key is queued; must be below MAX_TIMEOUT.
        :param key: Identifier of the watcher entry.
        :param queue: Stored in the entry under 'queue'.
        :param message: Stored in the entry under 'message'.
        :raises ValueError: ``timeout`` is not below MAX_TIMEOUT.
        """
        self._schedule(timeout, key, {
            'action': WatcherAction.Message,
            'queue': queue,
            'message': message
        })

    def touch_task(self, timeout: int, key: str, worker: str, task_key: str):
        """Register (or refresh) a watcher entry with the ``TimeoutTask`` action.

        :param timeout: Seconds from now at which the key is queued; must be below MAX_TIMEOUT.
        :param key: Identifier of the watcher entry.
        :param worker: Stored in the entry under 'worker'.
        :param task_key: Stored in the entry under 'task_key'.
        :raises ValueError: ``timeout`` is not below MAX_TIMEOUT.
        """
        self._schedule(timeout, key, {
            'action': WatcherAction.TimeoutTask,
            'worker': worker,
            'task_key': task_key
        })

    def clear(self, key: str):
        """Remove ``key`` from the watcher queue and from the watcher hash."""
        self.queue.remove(key)
        self.hash.pop(key)
Esempio n. 12
0
    def submit(self, submission_obj: SubmissionObject, local_files: List = None, cleanup=True, completed_queue=None):
        """Submit several files in a single submission.

        After this method runs, there should be no local copies of the file left.

        :param submission_obj: Submission description; its files, metadata and
                               params (classification, runtime_excluded) are
                               updated in place.
        :param local_files: Paths of local files to submit; only used when
                            submission_obj.files is empty.
        :param cleanup: Passed to _ready_file; when true the local files are
                        also unlinked when the method exits, whether it
                        succeeded or raised.
        :param completed_queue: Forwarded to the dispatcher's dispatch_submission.
        :return: The Submission object that was saved and dispatched.
        :raises SubmissionException: No files were provided, a file is empty,
                                     or a file exceeds the configured maximum
                                     size while ignore_size is not set.
        """
        if local_files is None:
            local_files = []

        try:
            # The ttl parameter is expressed in days; no ttl means no expiry
            expiry = now_as_iso(submission_obj.params.ttl * 24 * 60 * 60) if submission_obj.params.ttl else None
            max_size = self.config.submission.max_file_size

            if len(submission_obj.files) == 0:
                # No file listed in the submission: build the list from the local files
                if len(local_files) == 0:
                    raise SubmissionException("No files found to submit...")

                for local_file in local_files:
                    # Upload/download, extract, analyze files
                    file_hash, size, new_metadata = self._ready_file(local_file, expiry,
                                                                     str(submission_obj.params.classification),
                                                                     cleanup, upload=True)
                    # 'name' and 'classification' are taken out of the returned
                    # metadata; whatever remains is merged into the submission metadata
                    new_name = new_metadata.pop('name', safe_str(os.path.basename(local_file)))
                    submission_obj.params.classification = new_metadata.pop('classification',
                                                                            submission_obj.params.classification)
                    submission_obj.metadata.update(**flatten(new_metadata))

                    # Check that after we have resolved exactly what to pass on, that it
                    # remains a valid target for scanning
                    if size > max_size and not submission_obj.params.ignore_size:
                        msg = "File too large (%d > %d). Submission failed" % (size, max_size)
                        raise SubmissionException(msg)
                    elif size == 0:
                        msg = "File empty. Submission failed"
                        raise SubmissionException(msg)

                    submission_obj.files.append(File({
                        'name': new_name,
                        'size': size,
                        'sha256': file_hash,
                    }))
            else:
                # Files are already listed: fetch each one from the filestore
                # into a temporary file and process it from there
                for f in submission_obj.files:
                    temporary_path = None
                    try:
                        fd, temporary_path = tempfile.mkstemp(prefix="submission.submit")
                        os.close(fd)  # We don't need the file descriptor open
                        self.filestore.download(f.sha256, temporary_path)
                        file_hash, size, new_metadata = self._ready_file(temporary_path, expiry,
                                                                         str(submission_obj.params.classification),
                                                                         cleanup, sha256=f.sha256)

                        # Same metadata handling as for local files, with the
                        # existing file name as the fallback name
                        new_name = new_metadata.pop('name', f.name)
                        submission_obj.params.classification = new_metadata.pop('classification',
                                                                                submission_obj.params.classification)
                        submission_obj.metadata.update(**flatten(new_metadata))

                        # Check that after we have resolved exactly what to pass on, that it
                        # remains a valid target for scanning
                        if size > max_size and not submission_obj.params.ignore_size:
                            msg = "File too large (%d > %d). Submission failed" % (size, max_size)
                            raise SubmissionException(msg)
                        elif size == 0:
                            msg = "File empty. Submission failed"
                            raise SubmissionException(msg)

                        # Only fill in the size when the entry did not carry one
                        if f.size is None:
                            f.size = size

                        f.name = new_name
                        f.sha256 = file_hash

                    finally:
                        # The temporary copy is removed whether processing succeeded or not
                        if temporary_path:
                            if os.path.exists(temporary_path):
                                os.unlink(temporary_path)

            # Initialize the temporary data from the submission parameter
            if submission_obj.params.initial_data:
                try:
                    # The data is stored under the name derived from the first file's hash
                    temp_hash_name = get_temporary_submission_data_name(submission_obj.sid,
                                                                        submission_obj.files[0].sha256)
                    temporary_submission_data = ExpiringHash(temp_hash_name, host=self.redis)
                    temporary_submission_data.multi_set(json.loads(submission_obj.params.initial_data))
                except ValueError as err:
                    # A ValueError here only produces a warning; the submission continues
                    self.log.warning(f"[{submission_obj.sid}] could not process initialization data: {err}")

            # Clearing runtime_excluded on initial submit or resubmit
            submission_obj.params.services.runtime_excluded = []

            # We should now have all the information we need to construct a submission object
            sub = Submission(dict(
                archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                classification=submission_obj.params.classification,
                error_count=0,
                errors=[],
                expiry_ts=expiry,
                file_count=len(submission_obj.files),
                files=submission_obj.files,
                max_score=0,
                metadata=submission_obj.metadata,
                params=submission_obj.params,
                results=[],
                sid=submission_obj.sid,
                state='submitted'
            ))
            # Save the submission first, then hand it to the dispatcher
            self.datastore.submission.save(sub.sid, sub)

            self.log.debug("Submission complete. Dispatching: %s", sub.sid)
            self.dispatcher.dispatch_submission(sub, completed_queue=completed_queue)

            return sub
        finally:
            # Just in case this method fails clean up local files
            if cleanup:
                for path in local_files:
                    if path and os.path.exists(path):
                        # noinspection PyBroadException
                        try:
                            os.unlink(path)
                        except Exception:
                            # Best effort: a failed delete is logged, not raised
                            self.log.error("Couldn't delete dangling file %s", path)
Esempio n. 13
0
class RandomService(ServerBase):
    """Replaces everything past the dispatcher.

    Including service API, in the future probably include that in this test.
    """
    def __init__(self, datastore=None, filestore=None):
        """Prepare stores, per-service counters and queues, and the dispatch client.

        :param datastore: Datastore handle; obtained from forge when omitted.
        :param filestore: Filestore handle; obtained from forge when omitted.
        """
        super().__init__('assemblyline.randomservice')
        self.config = forge.get_config()
        self.datastore = datastore or forge.get_datastore()
        self.filestore = filestore or forge.get_filestore()
        self.client_id = get_random_id()
        self.service_state_hash = ExpiringHash(SERVICE_STATE_HASH, ttl=30 * 60)

        # One metrics counter and one service queue per key of service_delta
        self.counters = {
            n: MetricsFactory('service', Metrics, name=n, config=self.config)
            for n in self.datastore.service_delta.keys()
        }
        self.queues = [
            forge.get_service_queue(name)
            for name in self.datastore.service_delta.keys()
        ]
        self.dispatch_client = DispatchClient(self.datastore)
        self.service_info = CachedObject(self.datastore.list_all_services,
                                         kwargs={'as_obj': False})

    def run(self):
        """Consume service tasks and answer each with a random result or error.

        Loops while ``self.running`` is true. For every task pulled from the
        service queues, a random draw decides whether a randomly generated
        Result is reported through ``service_finished`` or a randomly
        generated Error through ``service_failed``.
        """
        self.log.info("Random service result generator ready!")
        self.log.info("Monitoring queues:")
        for q in self.queues:
            self.log.info(f"\t{q.name}")

        self.log.info("Waiting for messages...")
        while self.running:
            # Reset Idle flags
            for s in self.service_info:
                if s['enabled']:
                    # Entry is (service name, status, timestamp 35 seconds ahead)
                    # NOTE(review): the timestamp looks like a state expiry - confirm
                    self.service_state_hash.set(
                        f"{self.client_id}_{s['name']}",
                        (s['name'], ServiceStatus.Idle, time.time() + 30 + 5))

            # Wait up to one second for a message on any of the service queues
            message = select(*self.queues, timeout=1)
            if not message:
                continue

            archive_ts = now_as_iso(
                self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            # dtl is multiplied as a number of days; a falsy dtl means no expiry
            if self.config.submission.dtl:
                expiry_ts = now_as_iso(self.config.submission.dtl * 24 * 60 *
                                       60)
            else:
                expiry_ts = None
            queue, msg = message
            task = ServiceTask(msg)

            # Skip the task when it cannot be added to the running tasks table
            # NOTE(review): presumably a falsy return means it is already there - confirm
            if not self.dispatch_client.running_tasks.add(
                    task.key(), task.as_primitives()):
                continue

            # Set service busy flag
            self.service_state_hash.set(
                f"{self.client_id}_{task.service_name}",
                (task.service_name, ServiceStatus.Running,
                 time.time() + 30 + 5))

            # METRICS
            self.counters[task.service_name].increment('execute')
            # METRICS (not caching here so always miss)
            self.counters[task.service_name].increment('cache_miss')

            self.log.info(
                f"\tQueue {queue} received a new task for sid {task.sid}.")
            # Draw 1..10: 1 produces an error, 2-8 a full random result,
            # 9-10 a minimal random result
            action = random.randint(1, 10)
            if action >= 2:
                if action > 8:
                    result = random_minimal_obj(Result)
                else:
                    result = random_model_obj(Result)
                result.sha256 = task.fileinfo.sha256
                result.response.service_name = task.service_name
                result.archive_ts = archive_ts
                result.expiry_ts = expiry_ts
                # Drop the first (depth + 2) extracted and supplementary files,
                # so deeper tasks keep fewer of the generated files
                result.response.extracted = result.response.extracted[task.
                                                                      depth +
                                                                      2:]
                result.response.supplementary = result.response.supplementary[
                    task.depth + 2:]
                result_key = Result.help_build_key(
                    sha256=task.fileinfo.sha256,
                    service_name=task.service_name,
                    service_version='0',
                    is_empty=result.is_empty())

                self.log.info(
                    f"\t\tA result was generated for this task: {result_key}")

                # Make sure every file referenced by the result exists in both
                # the datastore and the filestore
                new_files = result.response.extracted + result.response.supplementary
                for f in new_files:
                    if not self.datastore.file.get(f.sha256):
                        random_file = random_model_obj(File)
                        random_file.archive_ts = archive_ts
                        random_file.expiry_ts = expiry_ts
                        random_file.sha256 = f.sha256
                        self.datastore.file.save(f.sha256, random_file)
                    if not self.filestore.exists(f.sha256):
                        # The stored content is the sha256 string itself
                        self.filestore.put(f.sha256, f.sha256)

                # Pause for a random 0 to 2 seconds before reporting
                time.sleep(random.randint(0, 2))

                self.dispatch_client.service_finished(task.sid, result_key,
                                                      result)

                # METRICS
                if result.result.score > 0:
                    self.counters[task.service_name].increment('scored')
                else:
                    self.counters[task.service_name].increment('not_scored')

            else:
                # Build a random error with one of three error types
                error = random_model_obj(Error)
                error.archive_ts = archive_ts
                error.expiry_ts = expiry_ts
                error.sha256 = task.fileinfo.sha256
                error.response.service_name = task.service_name
                error.type = random.choice(
                    ["EXCEPTION", "SERVICE DOWN", "SERVICE BUSY"])

                error_key = error.build_key('0')

                self.log.info(
                    f"\t\tA {error.response.status}:{error.type} "
                    f"error was generated for this task: {error_key}")

                self.dispatch_client.service_failed(task.sid, error_key, error)

                # METRICS
                if error.response.status == "FAIL_RECOVERABLE":
                    self.counters[task.service_name].increment(
                        'fail_recoverable')
                else:
                    self.counters[task.service_name].increment(
                        'fail_nonrecoverable')
class DispatchHash:
    def __init__(self,
                 sid: str,
                 client: Union[Redis, StrictRedis],
                 fetch_results=False):
        """Bind the redis-backed bookkeeping structures of a single submission.

        :param sid: Submission ID from which every redis key name is derived.
        :param client: Redis connection used by all the structures.
        :param fetch_results: When true, fill the local result cache through all_results().
        """
        self.client = client
        self.sid = sid
        self._dispatch_key = f'{sid}{dispatch_tail}'
        self._finish_key = f'{sid}{finished_tail}'
        self._finish = client.register_script(finish_script)

        # Schedules computed for the dispatcher are cached here so they are not
        # rebuilt repeatedly, and so the UI can be told which services are pending
        self.schedules = ExpiringHash(f'dispatch-hash-schedules-{sid}', host=client)

        # Per-file count of the services that are still outstanding
        self._outstanding_service_count = ExpiringHash(f'dispatch-hash-files-{sid}', host=client)
        # Which file was extracted by what, kept to rebuild the file tree later
        self._file_tree = ExpiringSet(f'dispatch-hash-parents-{sid}', host=client)
        self._attempts = ExpiringHash(f'dispatch-hash-attempts-{sid}', host=client)

        # Local caches of the file set and of the finished table
        self._cached_files = set(self._outstanding_service_count.keys())
        self._cached_results = {}
        if fetch_results:
            self._cached_results = self.all_results()

        # Submission-level errors that are not the terminal error of a service
        self._other_errors = ExpiringSet(f'dispatch-hash-errors-{sid}', host=client)

        # TODO set these expire times from the global time limit for submissions
        for redis_key in (self._dispatch_key, self._finish_key):
            retry_call(client.expire, redis_key, 60 * 60)

    def add_file(self, file_hash: str, file_limit, parent_hash) -> bool:
        """Register a file as part of the submission.

        :param file_hash: Hash of the file being added.
        :param file_limit: Maximum number of files allowed in the submission.
        :param parent_hash: Hash of the parent file; falsy when there is none.
        :return: True when the file is (or already was) part of the submission,
                 False when it was rejected.
        """
        # Record the parent relationship, or the bare hash when there is no parent
        tree_entry = f'{file_hash}-{parent_hash}' if parent_hash else file_hash
        self._file_tree.add(tree_entry)

        # Already known locally: no remote check needed
        if file_hash in self._cached_files:
            return True

        # Locally known to be full and the file is not in the set: reject right away
        if len(self._cached_files) >= file_limit:
            return False

        # The local view is inconclusive, so ask the remote hash:
        #   0    => the file already existed (still accepted)
        #   1    => the file was newly added
        #   None => over the file limit (rejected)
        outcome = self._outstanding_service_count.limited_add(file_hash, 0, file_limit)
        if outcome is None:
            return False

        # Remember the accepted file so the remote check is not repeated
        self._cached_files.add(file_hash)
        return True

    def add_error(self, error_key: str) -> bool:
        """Attach an error to the submission.

        NOTE: Meant for errors raised outside of those handled via 'fail_*recoverable'.

        :param error_key: Key of the error to record.
        :return: True when the error is new, False when it is a duplicate.
        """
        newly_added = self._other_errors.add(error_key)
        return newly_added > 0

    def dispatch(self, file_hash: str, service: str):
        """Record that ``service`` has been dispatched for the given file hash."""
        field = f"{file_hash}-{service}"
        was_set = retry_call(self.client.hset, self._dispatch_key, field, time.time())
        # The outstanding counter only moves when hset returns a truthy value
        if was_set:
            self._outstanding_service_count.increment(file_hash, 1)

    def drop_dispatch(self, file_hash: str, service: str):
        """Undo the bookkeeping of a dispatch that turned out to be unneeded."""
        field = f"{file_hash}-{service}"
        removed = retry_call(self.client.hdel, self._dispatch_key, field)
        # The outstanding counter only moves when hdel returns a truthy value
        if removed:
            self._outstanding_service_count.increment(file_hash, -1)

    def dispatch_count(self):
        """Return the number of entries in the dispatch table of this submission."""
        entry_count = retry_call(self.client.hlen, self._dispatch_key)
        return entry_count

    def dispatch_time(self, file_hash: str, service: str) -> float:
        """Return the time recorded when this sha/service pair was dispatched.

        :return: The stored timestamp as a float, or 0 when the pair has no entry.
        """
        field = f"{file_hash}-{service}"
        stored = retry_call(self.client.hget, self._dispatch_key, field)
        return 0 if stored is None else float(stored)

    def all_dispatches(self) -> Dict[str, Dict[str, float]]:
        """Read the whole dispatch table of this submission.

        :return: Mapping of file hash -> service name -> stored timestamp.
        """
        table = {}
        raw_rows = retry_call(self.client.hgetall, self._dispatch_key)
        for raw_key, raw_time in raw_rows.items():
            # Keys are bytes of the form b'<file_hash>-<service>'; only the
            # first dash separates the two parts
            hash_part, service_part = raw_key.split(b'-', maxsplit=1)
            per_file = table.setdefault(hash_part.decode(), {})
            per_file[service_part.decode()] = float(raw_time)
        return table

    def fail_recoverable(self,
                         file_hash: str,
                         service: str,
                         error_key: str = None):
        """Clear the dispatch of a service task that failed but should be retried.

        Once this returns the service is no longer marked as dispatched, and its
        status can't be updated until it is dispatched again.

        :param file_hash: Hash of the file the task was running on.
        :param service: Name of the service that failed.
        :param error_key: Optional error key to record with the submission.
        """
        if error_key:
            self._other_errors.add(error_key)

        field = f"{file_hash}-{service}"
        retry_call(self.client.hdel, self._dispatch_key, field)
        # The outstanding counter is decremented whatever hdel returned
        self._outstanding_service_count.increment(file_hash, -1)

    def fail_nonrecoverable(self, file_hash: str, service,
                            error_key) -> Tuple[int, bool]:
        """Record an error as the final outcome of a service task that must not be retried.

        Goes through the same `_finish` call as `finish`, with an 'error' row
        in place of a 'result' row.
        """
        row = json.dumps(['error', error_key, 0, False, ''])
        finish_args = [self.sid, file_hash, service, row]
        return retry_call(self._finish, args=finish_args)

    def finish(self,
               file_hash,
               service,
               result_key,
               score,
               classification,
               drop=False) -> Tuple[int, bool]:
        """
        Record the result of a service task through the `_finish` call, which as a single transaction:
         - removes the service from the dispatched list
         - adds the file to the finished list, with the given result key
         - returns the number of items in the dispatched list and if this was a duplicate call to finish
        """
        row = json.dumps(
            ['result', result_key, score, drop,
             str(classification)])
        finish_args = [self.sid, file_hash, service, row]
        return retry_call(self._finish, args=finish_args)

    def finished_count(self) -> int:
        """Return the number of entries in this submission's finished hash."""
        entry_count = retry_call(self.client.hlen, self._finish_key)
        return entry_count

    def finished(self, file_hash, service) -> Union[DispatchRow, None]:
        """Return the finished row of a file/service pair, or None when there is none.

        The local result cache is consulted first; the server is only queried
        when the cache has no truthy entry for the pair.
        """
        cached = self._cached_results.get(file_hash, {}).get(service, None)
        if cached:
            return cached

        stored = retry_call(self.client.hget, self._finish_key,
                            f"{file_hash}-{service}")
        return DispatchRow(*json.loads(stored)) if stored else None

    def all_finished(self) -> bool:
        """Report whether at least one task finished and nothing is still dispatched."""
        # The dispatch table is only queried once something has finished.
        if not self.finished_count() > 0:
            return False
        return self.dispatch_count() == 0

    def all_results(self) -> Dict[str, Dict[str, DispatchRow]]:
        """Read the whole finished hash.

        :return: table[file_hash][service_name] -> DispatchRow
        """
        entries = retry_call(self.client.hgetall, self._finish_key)
        table = {}
        for raw_key, raw_row in entries.items():
            # Keys are bytes of the form b"<file_hash>-<service>"; only the
            # first dash separates the two parts.
            hash_part, service_part = raw_key.split(b'-', maxsplit=1)
            sha = hash_part.decode()
            name = service_part.decode()
            table.setdefault(sha, {})[name] = DispatchRow(*json.loads(raw_row))
        return table

    def all_extra_errors(self):
        """Return the members of the error set kept outside the dispatch table."""
        extra_errors = self._other_errors.members()
        return extra_errors

    def all_files(self):
        """Return the keys of the outstanding service counter (one per tracked file)."""
        tracked = self._outstanding_service_count.keys()
        return tracked

    def file_tree(self):
        """Build a mapping from each file to the list of its parent files.

        Edges are stored as "child-parent" strings; an entry without a dash is
        a root file of the submission and contributes a None parent.
        """
        parents_of = {}
        for edge in self._file_tree.members():
            if '-' not in edge:
                child, parent = edge, None
            else:
                child, parent = edge.split('-')
            parents_of.setdefault(child, []).append(parent)
        return parents_of

    def delete(self):
        """Remove every structure of this dispatch table from the redis server."""
        # Raw hashes are removed by key through the client.
        for key in (self._dispatch_key, self._finish_key):
            retry_call(self.client.delete, key)

        # The helper structures each expose their own delete().
        helpers = (
            self.schedules,
            self._outstanding_service_count,
            self._file_tree,
            self._other_errors,
            self._attempts,
        )
        for helper in helpers:
            helper.delete()
def client(redis, storage, heuristics, dispatch_client):
    """Yield a test client of the app with the task status table bound to ``redis``.

    ``storage``, ``heuristics`` and ``dispatch_client`` are not used in the
    body (presumably requested for their setup side effects; confirm against
    the test module this came from).
    """
    test_client = app.app.test_client()
    status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60 * 30, host=redis)
    task.status_table = status_table
    yield test_client
Esempio n. 16
0
class WatcherServer(CoreBase):
    def __init__(self, redis=None, redis_persist=None):
        """Set up the watcher's redis structures, metrics counter and optional APM client."""
        super().__init__('assemblyline.watcher',
                         redis=redis,
                         redis_persist=redis_persist)

        volatile = self.redis
        persistent = self.redis_persist

        # Watcher structures
        self.hash = ExpiringHash(name=WATCHER_HASH,
                                 ttl=MAX_TIMEOUT,
                                 host=persistent)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, persistent)

        # Task management structures
        # TODO, move to persistant?
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH,
                                          host=volatile)
        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE,
                                               host=persistent)

        # Metrics tracking
        self.counter = MetricsFactory(metrics_type='watcher',
                                      schema=Metrics,
                                      name='watcher',
                                      redis=volatile,
                                      config=self.config)

        # APM export is enabled only when a server URL is configured.
        apm_url = self.config.core.metrics.apm_server.server_url
        if apm_url is None:
            self.apm_client = None
            return

        self.log.info(f"Exporting application metrics to: {apm_url}")
        elasticapm.instrument()
        self.apm_client = elasticapm.Client(server_url=apm_url,
                                            service_name="watcher")

    def try_run(self):
        """Main loop: deliver every expired watch message while the server is running.

        Each iteration pulls the expired keys from the watcher queue, handles
        them one by one (a failure on one message is logged and does not stop
        the others), records cpu/busy time and sleeps briefly when nothing
        was pulled.
        """
        metrics = self.counter
        apm = self.apm_client

        while self.running:
            self.heartbeat()

            # Download all messages from the queue that have expired
            now_seconds, _ = retry_call(self.redis.time)
            expired_keys = self.queue.dequeue_range(0, now_seconds)

            started_cpu = time.process_time()
            started_wall = time.time()

            for key in expired_keys:
                # Start of transaction
                if apm:
                    apm.begin_transaction('process_messages')

                message = self.hash.pop(key)
                if not message:
                    # Nothing stored for this key any more: end the
                    # transaction as a duplicate and move on.
                    if apm:
                        apm.end_transaction('watch_message', 'duplicate')

                    self.log.warning(
                        f'Handled watch twice: {key} {len(key)} {type(key)}')
                    continue

                try:
                    if message['action'] == WatcherAction.TimeoutTask:
                        self.cancel_service_task(message['task_key'],
                                                 message['worker'])
                    else:
                        target = NamedQueue(message['queue'], self.redis)
                        target.push(message['message'])

                    self.counter.increment('expired')
                    # End of transaction (success)
                    if apm:
                        apm.end_transaction('watch_message', 'success')
                except Exception as error:
                    # End of transaction (exception)
                    if apm:
                        apm.end_transaction('watch_message', 'error')

                    self.log.exception(error)

            metrics.increment_execution_time(
                'cpu_seconds',
                time.process_time() - started_cpu)
            metrics.increment_execution_time('busy_seconds',
                                             time.time() - started_wall)

            # Avoid a busy loop when the queue had nothing for us
            if not expired_keys:
                time.sleep(0.1)

    def cancel_service_task(self, task_key, worker):
        """Handle a service task that is believed to have timed out.

        The task is claimed from the running tasks table, marked as a
        recoverable failure, dispatched again and pushed back on its service
        queue; the scaler is then asked to deal with the worker container.
        """
        # If the task is no longer in running tasks it finished JUST before
        # timing out, so there is nothing to cancel.
        raw_task = self.running_tasks.pop(task_key)
        if not raw_task:
            return

        # The task is ours now; even if the worker finishes, its result will be ignored
        task = Task(raw_task)
        service_name = task.service_name
        sha256 = task.fileinfo.sha256
        self.log.info(
            f"[{task.sid}] Service {service_name} timed out on {sha256}.")

        # Mark the previous attempt as invalid and redispatch it
        dispatch_table = DispatchHash(task.sid, self.redis)
        dispatch_table.fail_recoverable(sha256, service_name)
        dispatch_table.dispatch(sha256, service_name)
        service_queue = get_service_queue(service_name, self.redis)
        service_queue.push(task.priority, task.as_primitives())

        # Killing the container is delegated to the scaler, which already has
        # root access and can double check that the service name and container
        # id match before killing anything.
        self.scaler_timeout_queue.push(
            dict(service=service_name, container=worker))

        # Report to the metrics system that a recoverable error has occurred for that service
        export_metrics_once(service_name,
                            ServiceMetrics,
                            dict(fail_recoverable=1),
                            host=worker,
                            counter_type='service')
    def __init__(self,
                 config=None,
                 datastore=None,
                 redis=None,
                 redis_persist=None):
        """Set up scaler state and pick the container controller (Kubernetes or Docker)."""
        super().__init__('assemblyline.scaler',
                         config=config,
                         datastore=datastore,
                         redis=redis,
                         redis_persist=redis_persist)

        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE,
                                               host=self.redis_persist)
        self.error_count = {}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH,
                                         host=self.redis,
                                         ttl=30 * 60)

        labels = dict(app='assemblyline', section='service')

        if not KUBERNETES_AL_CONFIG:
            self.log.info("Loading Docker cluster interface.")
            scaler_config = self.config.core.scaler
            self.controller = DockerController(
                logger=self.log,
                prefix=NAMESPACE,
                cpu_overallocation=scaler_config.cpu_overallocation,
                memory_overallocation=scaler_config.memory_overallocation,
                labels=labels)
            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_HOST_PATH:
                classification_mount = (
                    CLASSIFICATION_HOST_PATH,
                    '/etc/assemblyline/classification.yml')
                self.controller.global_mounts.append(classification_mount)
        else:
            self.log.info(
                f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}"
            )
            self.controller = KubernetesController(
                logger=self.log,
                prefix='alsvc_',
                labels=labels,
                namespace=NAMESPACE,
                priority='al-service-priority')
            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_CONFIGMAP:
                self.controller.config_mount(
                    'classification-config',
                    config_map=CLASSIFICATION_CONFIGMAP,
                    key=CLASSIFICATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/classification.yml')

        self.profiles: Dict[str, ServiceProfile] = {}

        # Prepare a single threaded scheduler
        export_interval = self.config.core.metrics.export_interval
        self.state = collection.Collection(period=export_interval)
        self.scheduler = sched.scheduler()
        self.scheduler_stopped = threading.Event()
Esempio n. 18
0
class TaskingClient:
    """A helper class to simplify tasking for privileged services and service-server.

    This tool helps take care of interactions between the filestore,
    datastore, dispatcher, and any sources of files to be processed.
    """
    def __init__(self,
                 datastore: AssemblylineDatastore = None,
                 filestore: FileStore = None,
                 config=None,
                 redis=None,
                 redis_persist=None,
                 identify=None):
        """Wire up the datastore, filestore, dispatcher client and helper caches.

        Any collaborator that is not supplied is obtained through ``forge``.
        When no ``identify`` object is supplied the client creates its own and
        marks it for cleanup in ``stop``.
        """
        self.log = logging.getLogger('assemblyline.tasking_client')
        self.config = config or forge.CachedObject(forge.get_config)
        self.datastore = datastore or forge.get_datastore(self.config)
        self.dispatch_client = DispatchClient(self.datastore,
                                              redis=redis,
                                              redis_persist=redis_persist)
        self.event_sender = EventSender('changes.services', redis)
        self.filestore = filestore or forge.get_filestore(self.config)
        self.heuristic_handler = HeuristicHandler(self.datastore)

        # Index every known heuristic by its ID
        self.heuristics = {}
        for heuristic in self.datastore.list_all_heuristics():
            self.heuristics[heuristic.heur_id] = heuristic

        self.status_table = ExpiringHash(SERVICE_STATE_HASH,
                                         ttl=60 * 30,
                                         host=redis)

        safelister_kwargs = dict(log=self.log,
                                 config=config,
                                 datastore=self.datastore)
        self.tag_safelister = forge.CachedObject(forge.get_tag_safelister,
                                                 kwargs=safelister_kwargs,
                                                 refresh=300)

        # Only an identify object created here gets stopped by stop()
        self.cleanup = not identify
        self.identify = identify or forge.get_identify(
            config=self.config, datastore=self.datastore, use_cache=True)

    def __enter__(self):
        """Enter the context manager; the client itself is the managed object."""
        return self

    def __exit__(self, *_):
        """Leave the context manager by calling stop(); exception details are ignored."""
        self.stop()

    def stop(self):
        """Stop the identify helper, but only when the cleanup flag is set."""
        if not self.cleanup:
            return
        self.identify.stop()

    @elasticapm.capture_span(span_type='tasking_client')
    def upload_file(self,
                    file_path,
                    classification,
                    ttl,
                    is_section_image,
                    expected_sha256=None):
        """Identify a local file, save its record in the datastore and upload it to the filestore.

        :param file_path: Path of the file to upload.
        :param classification: Classification stored on the file record.
        :param ttl: Days until expiry; a falsy value stores no expiry.
        :param is_section_image: Forwarded to ``save_or_freshen_file``.
        :param expected_sha256: When given, must equal the identified SHA256.
        :raises TaskingClientException: when the identified SHA256 differs
                                        from ``expected_sha256``.
        """
        # Identify the file info of the uploaded file
        file_info = self.identify.fileinfo(file_path)

        # Validate SHA256 of the uploaded file before touching any store
        if expected_sha256 is not None and expected_sha256 != file_info[
                'sha256']:
            raise TaskingClientException(
                "Uploaded file does not match expected file hash. "
                f"[{file_info['sha256']} != {expected_sha256}]")

        archive_delay = self.config.datastore.ilm.days_until_archive * 24 * 60 * 60
        file_info['archive_ts'] = now_as_iso(archive_delay)
        file_info['classification'] = classification
        file_info['expiry_ts'] = now_as_iso(ttl * 24 * 60 *
                                            60) if ttl else None

        # Update the datastore with the uploaded file
        self.datastore.save_or_freshen_file(file_info['sha256'],
                                            file_info,
                                            file_info['expiry_ts'],
                                            file_info['classification'],
                                            is_section_image=is_section_image)

        # Upload file to the filestore (upload already checks if the file exists)
        self.filestore.upload(file_path, file_info['sha256'])

    # Service
    @elasticapm.capture_span(span_type='tasking_client')
    def register_service(self, service_data, log_prefix=""):
        """Register a service, its version delta and its heuristics in the datastore.

        ``service_data`` is modified in place: ``heuristics``,
        ``file_required`` and ``tool_version`` are popped, and missing
        ``update_channel``, ``registry_type`` and ``privileged`` values are
        filled in from the services configuration.

        :param service_data: Service manifest as a dictionary.
        :param log_prefix: Text prepended to the log lines emitted here.
        :return: dict with ``keep_alive`` (False when this service version was
                 newly saved), ``new_heuristics`` (IDs whose bulk update was
                 not a "noop") and ``service_config``.
        :raises ValueError: when a heuristic entry cannot be prepared; the
                            original error is chained as the cause. Any other
                            exception propagates unchanged.
        """
        keep_alive = True
        preferred_registry = self.config.services.preferred_registry_type

        # Get heuristics list
        heuristics = service_data.pop('heuristics', None)

        # Patch update_channel, registry_type before Service registration object creation
        service_data['update_channel'] = service_data.get(
            'update_channel', self.config.services.preferred_update_channel)
        service_data['docker_config']['registry_type'] = service_data['docker_config'] \
            .get('registry_type', preferred_registry)
        service_data['privileged'] = service_data.get(
            'privileged', self.config.services.prefer_service_privileged)
        for dep in service_data.get('dependencies', {}).values():
            # The value lives on the dependency's container, so look for an
            # explicit setting there first instead of always overwriting it.
            # The dependency-level key is still honoured as a fallback.
            container = dep['container']
            container['registry_type'] = container.get(
                'registry_type', dep.get('registry_type', preferred_registry))

        # Pop unused registration service_data
        for x in ['file_required', 'tool_version']:
            service_data.pop(x, None)

        # Create Service registration object
        service = Service(service_data)

        # Fix service version, we don't need to see the stable label
        service.version = service.version.replace('stable', '')

        # Save service if it doesn't already exist
        service_key = f'{service.name}_{service.version}'
        if not self.datastore.service.exists(service_key):
            self.datastore.service.save(service_key, service)
            self.datastore.service.commit()
            self.log.info(f"{log_prefix}{service.name} registered")
            keep_alive = False

        # Save service delta if it doesn't already exist
        if not self.datastore.service_delta.exists(service.name):
            self.datastore.service_delta.save(service.name,
                                              {'version': service.version})
            self.datastore.service_delta.commit()
            self.log.info(f"{log_prefix}{service.name} "
                          f"version ({service.version}) registered")

        new_heuristics = []
        if heuristics:
            plan = self.datastore.heuristic.get_bulk_plan()
            for index, heuristic in enumerate(heuristics):
                heuristic_id = f'#{index}'  # Set heuristic id to its position in the list for logging purposes
                try:
                    # Prefix the heuristic ID with the service name
                    heuristic[
                        'heur_id'] = f"{service.name.upper()}.{str(heuristic['heur_id'])}"

                    # Attack_id field is now a list, make it a list if we receive otherwise
                    attack_id = heuristic.get('attack_id', None)
                    if isinstance(attack_id, str):
                        heuristic['attack_id'] = [attack_id]

                    heuristic = Heuristic(heuristic)
                    heuristic_id = heuristic.heur_id
                    plan.add_upsert_operation(heuristic_id, heuristic)
                except Exception as e:
                    msg = f"{service.name} has an invalid heuristic ({heuristic_id}): {str(e)}"
                    self.log.exception(f"{log_prefix}{msg}")
                    # Keep the original failure attached as the cause
                    raise ValueError(msg) from e

            for item in self.datastore.heuristic.bulk(plan)['items']:
                if item['update']['result'] != "noop":
                    new_heuristics.append(item['update']['_id'])
                    self.log.info(
                        f"{log_prefix}{service.name} "
                        f"heuristic {item['update']['_id']}: {item['update']['result'].upper()}"
                    )

            self.datastore.heuristic.commit()

        service_config = self.datastore.get_service_with_delta(service.name,
                                                               as_obj=False)

        # Notify components watching for service config changes
        self.event_sender.send(service.name, {
            'operation': Operation.Added,
            'name': service.name
        })

        return dict(keep_alive=keep_alive,
                    new_heuristics=new_heuristics,
                    service_config=service_config or dict())

    # Task
    @elasticapm.capture_span(span_type='tasking_client')
    def get_task(self,
                 client_id,
                 service_name,
                 service_version,
                 service_tool_version,
                 metric_factory,
                 status_expiry=None,
                 timeout=30):
        """Fetch the next task for a service client, answering it from cache when possible.

        :return: ``(task_primitives, False)`` when the client must process the
                 task, ``(None, False)`` when no task arrived within
                 ``timeout``, and ``(None, True)`` when a cached result was
                 handed to the dispatcher on the client's behalf.
        :raises ServiceMissingException: when ``service_name`` is not a known
                                         service.
        """
        if status_expiry is None:
            status_expiry = time.time() + timeout

        try:
            service_data = self.dispatch_client.service_data[service_name]
        except KeyError:
            raise ServiceMissingException(
                "The service you're asking task for does not exist, try later",
                404)

        # Set the service status to Idle since we will be waiting for a task
        idle_status = (service_name, ServiceStatus.Idle, status_expiry)
        self.status_table.set(client_id, idle_status)

        # Getting a new task
        task = self.dispatch_client.request_work(client_id,
                                                 service_name,
                                                 service_version,
                                                 timeout=timeout)
        if not task:
            # We've reached the timeout and no task found in service queue
            return None, False

        # We've got a task to process, consider us busy
        running_status = (service_name, ServiceStatus.Running,
                          time.time() + service_data.timeout)
        self.status_table.set(client_id, running_status)
        metric_factory.increment('execute')

        result_key = Result.help_build_key(
            sha256=task.fileinfo.sha256,
            service_name=service_name,
            service_version=service_version,
            service_tool_version=service_tool_version,
            is_empty=False,
            task=task)

        # Caching is bypassed when either the task or the service says so
        if task.ignore_cache or service_data.disable_cache:
            metric_factory.increment('cache_skipped')
            return task.as_primitives(), False

        # Checking for previous results for this key
        cached = self.datastore.result.get_if_exists(result_key)
        if cached:
            metric_factory.increment('cache_hit')
            if cached.result.score:
                metric_factory.increment('scored')
            else:
                metric_factory.increment('not_scored')

            cached.archive_ts = now_as_iso(
                self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            if task.ttl:
                cached.expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)

            self.dispatch_client.service_finished(task.sid, result_key,
                                                  cached)
            return None, True

        # Checking for previous empty results for this key
        empty_key = f"{result_key}.e"
        if self.datastore.emptyresult.get_if_exists(empty_key):
            metric_factory.increment('cache_hit')
            metric_factory.increment('not_scored')
            empty_result = self.datastore.create_empty_result_from_key(
                result_key)
            self.dispatch_client.service_finished(task.sid,
                                                  f"{result_key}.e",
                                                  empty_result)
            return None, True

        # No luck with the cache, lets dispatch the task to a client
        metric_factory.increment('cache_miss')
        return task.as_primitives(), False

    @elasticapm.capture_span(span_type='tasking_client')
    def task_finished(self, service_task, client_id, service_name,
                      metric_factory):
        """Process the outcome a service client reports for a task.

        :return: ``dict(success=True)`` once a result or an error was handled,
                 ``dict(success=False, missing_files=[...])`` when the result
                 refers to files that are missing, or ``None`` when the
                 payload carries neither a ``result`` nor an ``error``.
        :raises ValueError: propagated unchanged from the handlers or model
                            construction.
        """
        exec_time = service_task.get('exec_time')

        try:
            task = ServiceTask(service_task['task'])

            # Task created a result
            if 'result' in service_task:
                missing_files = self._handle_task_result(
                    exec_time, task, service_task['result'], client_id,
                    service_name, service_task['freshen'], metric_factory)
                if not missing_files:
                    return dict(success=True)
                return dict(success=False, missing_files=missing_files)

            # Task created an error
            if 'error' in service_task:
                self._handle_task_error(exec_time, task,
                                        service_task['error'], client_id,
                                        service_name, metric_factory)
                return dict(success=True)

            return None

        except ValueError as e:  # Catch errors when building Task or Result model
            raise e

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_result(self, exec_time: int, task: ServiceTask,
                            result: Dict[str, Any], client_id, service_name,
                            freshen: bool, metric_factory):
        """Post-process a result reported by a service and hand it to the dispatcher.

        Steps, in order: optionally freshen the extracted/supplementary files,
        resolve section heuristics and total the score, stamp timestamps,
        filter the temporary submission data, safelist and validate tags,
        build the Result model, notify the dispatcher, update metrics and mark
        the client as idle.

        ``result`` is modified in place.

        :param exec_time: Execution time reported by the client (only used in
                          the log line, as milliseconds).
        :param task: The task the result belongs to.
        :param result: Raw result dictionary from the service.
        :param client_id: Identifier of the reporting client.
        :param service_name: Name of the service that produced the result.
        :param freshen: When true, verify and freshen the referenced files.
        :param metric_factory: Counter receiving 'scored' / 'not_scored'.
        :return: The list of missing file hashes when ``freshen`` is set and
                 some files are missing (nothing else is done in that case);
                 otherwise None.
        """
        def freshen_file(file_info_list, item):
            """Return True when the file is missing, else freshen its record and return False."""
            file_info = file_info_list.get(item['sha256'], None)
            if file_info is None or not self.filestore.exists(item['sha256']):
                return True
            else:
                # archive_ts / expiry_ts come from the enclosing scope and are
                # assigned before this helper is ever submitted.
                file_info['archive_ts'] = archive_ts
                file_info['expiry_ts'] = expiry_ts
                file_info['classification'] = item['classification']
                self.datastore.save_or_freshen_file(
                    item['sha256'],
                    file_info,
                    file_info['expiry_ts'],
                    file_info['classification'],
                    is_section_image=item.get('is_section_image', False))
            return False

        archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive *
                                24 * 60 * 60)
        # task.ttl is multiplied out as days; a falsy ttl means no expiry
        if task.ttl:
            expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)
        else:
            expiry_ts = None

        # Check if all files are in the filestore
        if freshen:
            missing_files = []
            # De-duplicated hashes of every extracted and supplementary file
            hashes = list(
                set([
                    f['sha256'] for f in result['response']['extracted'] +
                    result['response']['supplementary']
                ]))
            file_infos = self.datastore.file.multiget(hashes,
                                                      as_obj=False,
                                                      error_on_missing=False)

            with elasticapm.capture_span(
                    name="handle_task_result.freshen_files",
                    span_type="tasking_client"):
                # Leaving the executor's with-block waits for every submitted
                # job. The futures are keyed by sha256, so a hash listed more
                # than once keeps only the last future.
                with concurrent.futures.ThreadPoolExecutor(
                        max_workers=5) as executor:
                    res = {
                        f['sha256']: executor.submit(freshen_file, file_infos,
                                                     f)
                        for f in result['response']['extracted'] +
                        result['response']['supplementary']
                    }
                for k, v in res.items():
                    if v.result():
                        missing_files.append(k)

            # Stop here so the caller can report the missing files
            if missing_files:
                return missing_files

        # Add scores to the heuristics, if any section set a heuristic
        with elasticapm.capture_span(
                name="handle_task_result.process_heuristics",
                span_type="tasking_client"):
            total_score = 0
            for section in result['result']['sections']:
                zeroize_on_sig_safe = section.pop('zeroize_on_sig_safe', True)
                section['tags'] = flatten(section['tags'])
                if section.get('heuristic'):
                    # Heuristic IDs are stored prefixed with the upper-cased service name
                    heur_id = f"{service_name.upper()}.{str(section['heuristic']['heur_id'])}"
                    section['heuristic']['heur_id'] = heur_id
                    try:
                        section[
                            'heuristic'], new_tags = self.heuristic_handler.service_heuristic_to_result_heuristic(
                                section['heuristic'], self.heuristics,
                                zeroize_on_sig_safe)
                        # Merge the (name, value) tags returned with the
                        # heuristic, skipping values already present
                        for tag in new_tags:
                            section['tags'].setdefault(tag[0], [])
                            if tag[1] not in section['tags'][tag[0]]:
                                section['tags'][tag[0]].append(tag[1])
                        total_score += section['heuristic']['score']
                    except InvalidHeuristicException:
                        # An invalid heuristic is dropped from the section and adds no score
                        section['heuristic'] = None

        # Update the total score of the result
        result['result']['score'] = total_score

        # Add timestamps for creation, archive and expiry
        result['created'] = now_as_iso()
        result['archive_ts'] = archive_ts
        result['expiry_ts'] = expiry_ts

        # Pop the temporary submission data
        temp_submission_data = result.pop('temp_submission_data', None)
        if temp_submission_data:
            old_submission_data = {
                row.name: row.value
                for row in task.temporary_submission_data
            }
            # Keep only the keys that are new or whose value changed
            temp_submission_data = {
                k: v
                for k, v in temp_submission_data.items()
                if k not in old_submission_data or v != old_submission_data[k]
            }
            # Size is measured on the string form of each value
            big_temp_data = {
                k: len(str(v))
                for k, v in temp_submission_data.items()
                if len(str(v)) > self.config.submission.max_temp_data_length
            }
            if big_temp_data:
                big_data_sizes = [f"{k}={v}" for k, v in big_temp_data.items()]
                self.log.warning(
                    f"[{task.sid}] The following temporary submission keys where ignored because they are "
                    "bigger then the maximum data size allowed "
                    f"[{self.config.submission.max_temp_data_length}]: {' | '.join(big_data_sizes)}"
                )
                # Oversized values are dropped rather than truncated
                temp_submission_data = {
                    k: v
                    for k, v in temp_submission_data.items()
                    if k not in big_temp_data
                }

        # Process the tag values
        with elasticapm.capture_span(name="handle_task_result.process_tags",
                                     span_type="tasking_client"):
            for section in result['result']['sections']:
                # Perform tag safelisting
                tags, safelisted_tags = self.tag_safelister.get_validated_tag_map(
                    section['tags'])
                section['tags'] = unflatten(tags)
                section['safelisted_tags'] = safelisted_tags

                # Validate the remaining tags against the Tagging model;
                # `dropped` holds whatever did not pass
                section['tags'], dropped = construct_safe(
                    Tagging, section.get('tags', {}))

                # Set section score to zero and lower total score if service is set to zeroize score
                # and all tags were safelisted
                if section.pop('zeroize_on_tag_safe', False) and \
                        section.get('heuristic') and \
                        len(tags) == 0 and \
                        len(safelisted_tags) != 0:
                    result['result']['score'] -= section['heuristic']['score']
                    section['heuristic']['score'] = 0

                if dropped:
                    self.log.warning(
                        f"[{task.sid}] Invalid tag data from {service_name}: {dropped}"
                    )

        # Build the Result model and hand it to the dispatcher
        result = Result(result)
        result_key = result.build_key(
            service_tool_version=result.response.service_tool_version,
            task=task)
        self.dispatch_client.service_finished(task.sid, result_key, result,
                                              temp_submission_data)

        # Metrics
        if result.result.score > 0:
            metric_factory.increment('scored')
        else:
            metric_factory.increment('not_scored')

        self.log.info(
            f"[{task.sid}] {client_id} - {service_name} "
            f"successfully completed task {f' in {exec_time}ms' if exec_time else ''}"
        )

        # Mark the client as idle; the status entry is set to expire 5 seconds from now
        self.status_table.set(
            client_id, (service_name, ServiceStatus.Idle, time.time() + 5))

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_error(self, exec_time: int, task: ServiceTask,
                           error: Dict[str, Any], client_id, service_name,
                           metric_factory) -> None:
        """Record an error reported by a service and hand it to the dispatcher.

        The raw ``error`` dictionary is stamped with timestamps in place,
        turned into an Error model and passed to the dispatcher; the matching
        failure metric is incremented and the client is marked idle.
        """
        elapsed = f' in {exec_time}ms' if exec_time else ''
        self.log.info(f"[{task.sid}] {client_id} - {service_name} "
                      f"failed to complete task {elapsed}")

        # Add timestamps for creation, archive and expiry
        error['created'] = now_as_iso()
        archive_delay = self.config.datastore.ilm.days_until_archive * 24 * 60 * 60
        error['archive_ts'] = now_as_iso(archive_delay)
        if task.ttl:
            error['expiry_ts'] = now_as_iso(task.ttl * 24 * 60 * 60)

        error_model = Error(error)
        error_key = error_model.build_key(
            service_tool_version=error_model.response.service_tool_version,
            task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error_model)

        # Metrics
        recoverable = error_model.response.status == 'FAIL_RECOVERABLE'
        metric_factory.increment(
            'fail_recoverable' if recoverable else 'fail_nonrecoverable')

        # Mark the client as idle; the status entry is set to expire 5 seconds from now
        idle_status = (service_name, ServiceStatus.Idle, time.time() + 5)
        self.status_table.set(client_id, idle_status)
Esempio n. 19
0
class ScalerServer(ThreadedCoreBase):
    def __init__(self,
                 config=None,
                 datastore=None,
                 redis=None,
                 redis_persist=None):
        """Set up the scaler's shared state and its container controller.

        Args:
            config: Forwarded unchanged to the base class constructor.
            datastore: Forwarded unchanged to the base class constructor.
            redis: Forwarded unchanged to the base class constructor.
            redis_persist: Forwarded unchanged to the base class constructor.

        Raises:
            OSError: If ``/etc/assemblyline/config.yml`` cannot be read.
            KeyError: If an environment variable referenced as ``${NAME}`` in
                that file, or ``UI_SERVER``, is not set in the environment.
        """
        super().__init__('assemblyline.scaler',
                         config=config,
                         datastore=datastore,
                         redis=redis,
                         redis_persist=redis_persist)

        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE,
                                               host=self.redis_persist)
        # Per-service timestamps of recent errors, guarded by error_count_lock
        self.error_count_lock = threading.Lock()
        self.error_count: dict[str, list[float]] = {}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH,
                                         host=self.redis,
                                         ttl=30 * 60)
        self.service_event_sender = EventSender('changes.services',
                                                host=self.redis)
        self.service_change_watcher = EventWatcher(
            self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register('changes.services.*',
                                             self._handle_service_change_event)

        core_env: dict[str, str] = {}
        # If we have privileged services, we must be able to pass the necessary environment variables for them to
        # function properly.
        # Read the file through a context manager so the handle is closed
        # immediately instead of being left for the garbage collector.
        with open('/etc/assemblyline/config.yml', 'r') as config_file:
            raw_config = config_file.read()
        for secret in re.findall(r'\${\w+}', raw_config) + ['UI_SERVER']:
            env_name = secret.strip("${}")
            core_env[env_name] = os.environ[env_name]

        labels = {
            'app': 'assemblyline',
            'section': 'service',
            'privilege': 'service'
        }

        # Extra labels are configured as "key=value" strings
        if self.config.core.scaler.additional_labels:
            labels.update({
                k: v
                for k, v in (
                    _l.split("=")
                    for _l in self.config.core.scaler.additional_labels)
            })

        if KUBERNETES_AL_CONFIG:
            self.log.info(
                f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}"
            )
            self.controller = KubernetesController(
                logger=self.log,
                prefix='alsvc_',
                labels=labels,
                namespace=NAMESPACE,
                priority='al-service-priority',
                cpu_reservation=self.config.services.cpu_reservation,
                log_level=self.config.logging.log_level,
                core_env=core_env)
            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_CONFIGMAP:
                self.controller.config_mount(
                    'classification-config',
                    config_map=CLASSIFICATION_CONFIGMAP,
                    key=CLASSIFICATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/classification.yml')
            if CONFIGURATION_CONFIGMAP:
                self.controller.core_config_mount(
                    'assemblyline-config',
                    config_map=CONFIGURATION_CONFIGMAP,
                    key=CONFIGURATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/config.yml')
        else:
            self.log.info("Loading Docker cluster interface.")
            self.controller = DockerController(
                logger=self.log,
                prefix=NAMESPACE,
                labels=labels,
                log_level=self.config.logging.log_level,
                core_env=core_env)
            self._service_stage_hash.delete()

            if DOCKER_CONFIGURATION_PATH and DOCKER_CONFIGURATION_VOLUME:
                self.controller.core_mounts.append(
                    (DOCKER_CONFIGURATION_VOLUME, '/etc/assemblyline/'))

                # Write the current config and classification definition into
                # the directory that is mounted above.
                with open(
                        os.path.join(DOCKER_CONFIGURATION_PATH, 'config.yml'),
                        'w') as handle:
                    yaml.dump(self.config.as_primitives(), handle)

                with open(
                        os.path.join(DOCKER_CONFIGURATION_PATH,
                                     'classification.yml'), 'w') as handle:
                    yaml.dump(get_classification().original_definition, handle)

            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_HOST_PATH:
                self.controller.global_mounts.append(
                    (CLASSIFICATION_HOST_PATH,
                     '/etc/assemblyline/classification.yml'))

        # Information about services
        self.profiles: dict[str, ServiceProfile] = {}
        self.profiles_lock = threading.RLock()

        # Prepare a single threaded scheduler
        self.state = collection.Collection(
            period=self.config.core.metrics.export_interval)
        self.stopping = threading.Event()
        self.main_loop_exit = threading.Event()

        # Load the APM connection if any
        self.apm_client = None
        if self.config.core.metrics.apm_server.server_url:
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="scaler")

    def log_crashes(self, fn):
        """Return a wrapper around ``fn`` that logs exceptions instead of raising.

        A ``ServiceControlError`` is logged and then reported through
        ``handle_service_error`` for the service it names; any other
        ``Exception`` is only logged. The wrapper discards the return value
        of ``fn`` and always returns ``None``.
        """
        def guarded(*args, **kwargs):
            # noinspection PyBroadException
            try:
                fn(*args, **kwargs)
            except ServiceControlError as control_error:
                failed_service = control_error.service_name
                self.log.exception(
                    f"Error while managing service: {failed_service}")
                self.handle_service_error(failed_service)
            except Exception:
                self.log.exception(f'Crash in scaler: {fn.__name__}')

        # Copy fn's metadata (name, docstring, ...) onto the wrapper
        return functools.wraps(fn)(guarded)

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def add_service(self, profile: ServiceProfile):
        """Register ``profile`` for scaling and hand it to the controller.

        The starting instance count is the larger of the controller's current
        target for the service and the profile's ``min_instances``; it is
        stored as the desired, running and target instance counts.
        """
        # The lock is held for the whole registration: the scaling thread must
        # not try to adjust a deployment that is not fully added yet.
        with self.profiles_lock:
            current_target = self.controller.get_target(profile.name)
            profile.desired_instances = max(current_target,
                                            profile.min_instances)
            starting_scale = profile.desired_instances
            profile.running_instances = starting_scale
            profile.target_instances = starting_scale
            self.log.debug(
                f'Starting service {profile.name} with a target of {starting_scale}'
            )
            profile.last_update = time.time()
            self.profiles[profile.name] = profile
            self.controller.add_profile(profile, scale=starting_scale)

    def try_run(self):
        self.service_change_watcher.start()
        self.maintain_threads({
            'Log Container Events': self.log_container_events,
            'Process Timeouts': self.process_timeouts,
            'Service Configuration Sync': self.sync_services,
            'Service Adjuster': self.update_scaling,
            'Import Metrics': self.sync_metrics,
            'Export Metrics': self.export_metrics,
        })

    def stop(self):
        """Stop the base server, then the change watcher, then the controller."""
        super().stop()
        # Attributes are looked up one at a time so each component is only
        # touched after the previous one has been stopped.
        for attribute in ('service_change_watcher', 'controller'):
            getattr(self, attribute).stop()

    def _handle_service_change_event(self, data: ServiceChange):
        """React to a service change notification.

        A removed service is stopped, an incompatible one is ignored, and for
        every other operation the service is reloaded from the datastore and
        synchronised.
        """
        operation = data.operation

        if operation == Operation.Removed:
            self.log.info(
                f'Service appears to be deleted, removing {data.name}')
            self.stop_service(data.name, self.get_service_stage(data.name))
            return

        if operation == Operation.Incompatible:
            return

        service = self.datastore.get_service_with_delta(data.name)
        self._sync_service(service)

    def sync_services(self):
        """Reconcile the scaled services with the datastore while running.

        Each pass synchronises every service listed by the datastore, then
        stops any service that had a profile at the start of the pass but was
        not listed. ``self.sleep(SERVICE_SYNC_INTERVAL)`` is called between
        passes.
        """
        while self.running:
            with apm_span(self.apm_client, 'sync_services'):
                # Snapshot the names we are currently scaling
                with self.profiles_lock:
                    known_services = set(self.profiles.keys())

                # Get all the service data
                listed_services: set[str] = set()
                for service in self.datastore.list_all_services(full=True):
                    self._sync_service(service)
                    listed_services.add(service.name)

                # Find any services we have running, that are no longer in the database and remove them
                for stray_service in known_services - listed_services:
                    self.log.info(
                        f'Service appears to be deleted, removing stray {stray_service}'
                    )
                    self.stop_service(stray_service,
                                      self.get_service_stage(stray_service))

            self.sleep(SERVICE_SYNC_INTERVAL)

    def _sync_service(self, service: Service):
        """Bring the deployed state of one service in line with its definition.

        Steps, in order:
          1. Disable the service if its version does not start with
             ``FRAMEWORK_VERSION.SYSTEM_VERSION`` or if it waits for updates
             but has no update sources.
          2. Stop the service and return if it is not enabled.
          3. Prepare the dependency containers; launch them and advance the
             service stage when the stage is not ``Running`` or a dependency
             container key is missing.
          4. When the stage is ``Running``, add the service to scaling, or
             update its profile and restart what changed.

        Any exception is logged and reported through ``handle_service_error``
        rather than propagated.

        Args:
            service: The service definition to apply.
        """
        name = service.name
        stage = self.get_service_stage(service.name)
        default_settings = self.config.core.scaler.service_defaults
        # Unknown variables resolve to '' in the defaultdict
        image_variables: defaultdict[str, str] = defaultdict(str)
        image_variables.update(self.config.services.image_variables)

        def prepare_container(docker_config: DockerConfig) -> DockerConfig:
            """Substitute image variables and add missing default environment variables (in place)."""
            docker_config.image = Template(
                docker_config.image).safe_substitute(image_variables)
            set_keys = {var.name for var in docker_config.environment}
            for var in default_settings.environment:
                if var.name not in set_keys:
                    docker_config.environment.append(var)
            return docker_config

        # noinspection PyBroadException
        try:

            def disable_incompatible_service():
                """Disable the service locally and in the datastore, announcing it if the update succeeds."""
                service.enabled = False
                if self.datastore.service_delta.update(service.name, [
                    (self.datastore.service_delta.UPDATE_SET, 'enabled', False)
                ]):
                    # Raise awareness to other components by sending an event for the service
                    self.service_event_sender.send(service.name, {
                        'operation': Operation.Incompatible,
                        'name': service.name
                    })

            # Check if service considered compatible to run on Assemblyline?
            system_spec = f'{FRAMEWORK_VERSION}.{SYSTEM_VERSION}'
            if not service.version.startswith(system_spec):
                # If FW and SYS version don't prefix in the service version, we can't guarantee the service is compatible
                # Disable and treat it as incompatible due to service version.
                self.log.warning(
                    "Disabling service with incompatible version. "
                    f"[{service.version} != '{system_spec}.X.{service.update_channel}Y']."
                )
                disable_incompatible_service()
            elif service.update_config and service.update_config.wait_for_update and not service.update_config.sources:
                # All signatures sources from a signature-dependent service was removed
                # Disable and treat it as incompatible due to service configuration relative to source management
                self.log.warning(
                    "Disabling service with incompatible service configuration. "
                    "Signature-dependent service has no signature sources.")
                disable_incompatible_service()

            if not service.enabled:
                self.stop_service(service.name, stage)
                return

            # Build the docker config for the dependencies. For now the dependency blob values
            # aren't set for the change key going to kubernetes because everything about
            # the dependency config should be captured in change key that the function generates
            # internally. A change key is set for the service deployment as that includes
            # things like the submission params
            dependency_config: dict[str, Any] = {}
            dependency_blobs: dict[str, str] = {}
            for _n, dependency in service.dependencies.items():
                dependency.container = prepare_container(dependency.container)
                dependency_config[_n] = dependency
                dep_hash = get_id_from_data(dependency, length=16)
                dependency_blobs[
                    _n] = f"dh={dep_hash}v={service.version}p={service.privileged}"

            # Check if the service dependencies have been deployed.
            dependency_keys = []
            updater_ready = stage == ServiceStage.Running
            if service.update_config:
                for _n, dependency in dependency_config.items():
                    key = self.controller.stateful_container_key(
                        service.name, _n, dependency, '')
                    if key:
                        dependency_keys.append(_n + key)
                    else:
                        updater_ready = False

            # If stage is not set to running or a dependency container is missing start the setup process
            if not updater_ready:
                self.log.info(f'Preparing environment for {service.name}')
                # Move to the next service stage (do this first because the container we are starting may care)
                if service.update_config and service.update_config.wait_for_update:
                    self._service_stage_hash.set(name, ServiceStage.Update)
                    stage = ServiceStage.Update
                else:
                    self._service_stage_hash.set(name, ServiceStage.Running)
                    stage = ServiceStage.Running

                # Enable this service's dependencies before trying to launch the service containers
                dependency_internet = [
                    (dep_name, dependency.container.allow_internet_access)
                    for dep_name, dependency in dependency_config.items()
                ]

                self.controller.prepare_network(
                    service.name, service.docker_config.allow_internet_access,
                    dependency_internet)
                for _n, dependency in dependency_config.items():
                    self.log.info(f'Launching {service.name} dependency {_n}')
                    self.controller.start_stateful_container(
                        service_name=service.name,
                        container_name=_n,
                        spec=dependency,
                        labels={'dependency_for': service.name},
                        change_key=dependency_blobs.get(_n, ''))

            # If the conditions for running are met deploy or update service containers
            if stage == ServiceStage.Running:
                # Build the docker config for the service, we are going to either create it or
                # update it so we need to know what the current configuration is either way
                docker_config = prepare_container(service.docker_config)

                # Compute a blob of service properties not include in the docker config, that
                # should still result in a service being restarted when changed
                cfg_items = get_recursive_sorted_tuples(service.config)
                dep_keys = ''.join(sorted(dependency_keys))
                config_blob = (
                    f"c={cfg_items}sp={service.submission_params}"
                    f"dk={dep_keys}p={service.privileged}d={docker_config}")

                # Add the service to the list of services being scaled
                with self.profiles_lock:
                    if name not in self.profiles:
                        self.log.info(
                            f"Adding "
                            f"{f'privileged {service.name}' if service.privileged else service.name}"
                            " to scaling")
                        self.add_service(
                            ServiceProfile(
                                name=name,
                                min_instances=default_settings.min_instances,
                                growth=default_settings.growth,
                                shrink=default_settings.shrink,
                                config_blob=config_blob,
                                dependency_blobs=dependency_blobs,
                                backlog=default_settings.backlog,
                                max_instances=service.licence_count,
                                container_config=docker_config,
                                queue=get_service_queue(name, self.redis),
                                # Give service an extra 30 seconds to upload results
                                shutdown_seconds=service.timeout + 30,
                                privileged=service.privileged))

                    # Update RAM, CPU, licence requirements for running services
                    else:
                        profile = self.profiles[name]
                        profile.max_instances = service.licence_count
                        profile.privileged = service.privileged

                        for dependency_name, dependency_blob in dependency_blobs.items(
                        ):
                            # Use .get(): a dependency added to a service that is
                            # already being scaled has no recorded blob yet. Indexing
                            # would raise KeyError, which the handler below would
                            # count as a service error.
                            if profile.dependency_blobs.get(
                                    dependency_name) != dependency_blob:
                                self.log.info(
                                    f"Updating deployment information for {name}/{dependency_name}"
                                )
                                profile.dependency_blobs[
                                    dependency_name] = dependency_blob
                                self.controller.start_stateful_container(
                                    service_name=service.name,
                                    container_name=dependency_name,
                                    spec=dependency_config[dependency_name],
                                    labels={'dependency_for': service.name},
                                    change_key=dependency_blob)

                        if profile.config_blob != config_blob:
                            self.log.info(
                                f"Updating deployment information for {name}")
                            profile.container_config = docker_config
                            profile.config_blob = config_blob
                            self.controller.restart(profile)
                            self.log.info(
                                f"Deployment information for {name} replaced")

        except Exception:
            self.log.exception(
                f"Error applying service settings from: {service.name}")
            self.handle_service_error(service.name)

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def stop_service(self, name: str, current_stage: ServiceStage):
        """Shut down a service's dependencies and remove it from scaling.

        Dependency containers are stopped and the stage is recorded as ``Off``
        unless the service is already at that stage. The service is then
        removed from the profiles and its target set to zero if it has a
        profile or the controller reports a positive target for it.
        """
        if current_stage != ServiceStage.Off:
            # Disable this service's dependencies
            self.controller.stop_containers(labels={'dependency_for': name})

            # Mark this service as not running in the shared record
            self._service_stage_hash.set(name, ServiceStage.Off)

        # Stop any running disabled services
        still_scaled = (name in self.profiles
                        or self.controller.get_target(name) > 0)
        if not still_scaled:
            return

        self.log.info(f'Removing {name} from scaling')
        with self.profiles_lock:
            self.profiles.pop(name, None)
        self.controller.set_target(name, 0)

    def update_scaling(self):
        """Check if we need to scale any services up or down.

        Runs for as long as ``self.sleep(SCALE_INTERVAL)`` returns a truthy
        value. Each pass works on a deep copy of the profiles and:

          1. lowers the target of every service that wants fewer instances
             than it currently has;
          2. raises the target of every service below its ``min_instances``;
          3. hands out the remaining (overallocated) CPU and memory one
             instance at a time, always to the service with the fewest
             targeted instances, until nobody wants more or nothing fits;
          4. writes the changed targets back to the controller.
        """
        # NOTE(review): Pool is defined elsewhere; presumably `with pool:` waits
        # for the queued calls before leaving the block. Confirm against Pool.
        pool = Pool()
        while self.sleep(SCALE_INTERVAL):
            with apm_span(self.apm_client, 'update_scaling'):
                # Figure out what services are expected to be running and how many
                with elasticapm.capture_span('read_profiles'):
                    # Work on a deep copy so the lock is only held while copying
                    with self.profiles_lock:
                        all_profiles: dict[str,
                                           ServiceProfile] = copy.deepcopy(
                                               self.profiles)
                    raw_targets = self.controller.get_targets()
                    # Services the controller does not report count as zero
                    targets = {
                        _p.name: raw_targets.get(_p.name, 0)
                        for _p in all_profiles.values()
                    }

                # Debug dump of every profile's instance bounds and pressure
                for name, profile in all_profiles.items():
                    self.log.debug(f'{name}')
                    self.log.debug(
                        f'Instances \t{profile.min_instances} < {profile.desired_instances} | '
                        f'{targets[name]} < {profile.max_instances}')
                    self.log.debug(
                        f'Pressure \t{profile.shrink_threshold} < '
                        f'{profile.pressure} < {profile.growth_threshold}')

                #
                #   1.  Any processes that want to release resources can always be approved first
                #
                with pool:
                    for name, profile in all_profiles.items():
                        if targets[name] > profile.desired_instances:
                            self.log.info(
                                f"{name} wants less resources changing allocation "
                                f"{targets[name]} -> {profile.desired_instances}"
                            )
                            pool.call(self.controller.set_target, name,
                                      profile.desired_instances)
                            targets[name] = profile.desired_instances

                    #
                    #   2.  Any processes that aren't reaching their min_instances target must be given
                    #       more resources before anyone else is considered.
                    #       (This loop runs inside the same `with pool:` block as step 1.)
                    #
                    for name, profile in all_profiles.items():
                        if targets[name] < profile.min_instances:
                            self.log.info(
                                f"{name} isn't meeting minimum allocation "
                                f"{targets[name]} -> {profile.min_instances}")
                            pool.call(self.controller.set_target, name,
                                      profile.min_instances)
                            targets[name] = profile.min_instances

                #
                #   3.  Try to estimate available resources, and based on some metric grant the
                #       resources to each service that wants them. While this free memory
                #       pool might be spread across many nodes, we are going to treat it like
                #       it is one big one, and let the orchestration layer sort out the details.
                #

                # Recalculate the amount of free resources expanding the total quantity by the overallocation
                free_cpu, total_cpu = self.controller.cpu_info()
                used_cpu = total_cpu - free_cpu
                free_cpu = total_cpu * self.config.core.scaler.cpu_overallocation - used_cpu

                free_memory, total_memory = self.controller.memory_info()
                used_memory = total_memory - free_memory
                free_memory = total_memory * self.config.core.scaler.memory_overallocation - used_memory

                # Keep only the profiles that still want more instances than their
                # current target AND whose cpu/ram fit in what is left. trim() is a
                # closure: it reads the current values of targets, free_cpu and
                # free_memory on every call, so it must be re-run after each grant.
                def trim(prof: list[ServiceProfile]):
                    prof = [
                        _p for _p in prof
                        if _p.desired_instances > targets[_p.name]
                    ]
                    # Profiles that want more but do not fit; only used for logging
                    drop = [
                        _p for _p in prof
                        if _p.cpu > free_cpu or _p.ram > free_memory
                    ]
                    if drop:
                        summary = {_p.name: (_p.cpu, _p.ram) for _p in drop}
                        self.log.debug(
                            f"Can't make more because not enough resources {summary}"
                        )
                    prof = [
                        _p for _p in prof
                        if _p.cpu <= free_cpu and _p.ram <= free_memory
                    ]
                    return prof

                remaining_profiles: list[ServiceProfile] = trim(
                    list(all_profiles.values()))
                # The target values up until now should be in sync with the container orchestrator
                # create a copy, so we can track which ones change in the following loop
                old_targets = dict(targets)

                while remaining_profiles:
                    # TODO do we need to add balancing metrics other than 'least running' for this? probably
                    remaining_profiles.sort(key=lambda _p: targets[_p.name])

                    # Add one for the profile at the bottom
                    # NOTE(review): the cost is taken from container_config.ram_mb /
                    # cpu_cores while trim() compares _p.ram / _p.cpu; presumably
                    # these are the same quantities. Confirm in ServiceProfile.
                    free_memory -= remaining_profiles[
                        0].container_config.ram_mb
                    free_cpu -= remaining_profiles[
                        0].container_config.cpu_cores
                    targets[remaining_profiles[0].name] += 1

                    # Take out any services that should be happy now
                    remaining_profiles = trim(remaining_profiles)

                # Apply those adjustments we have made back to the controller
                with elasticapm.capture_span('write_targets'):
                    with pool:
                        for name, value in targets.items():
                            if name not in self.profiles:
                                # A service was probably added/removed while we were
                                # in the middle of this function
                                continue
                            # Record the target on the live profile (not the copy)
                            self.profiles[name].target_instances = value
                            old = old_targets[name]
                            # Only targets changed by step 3 are sent here; steps 1
                            # and 2 already queued their own set_target calls
                            if value != old:
                                self.log.info(
                                    f"Scaling service {name}: {old} -> {value}"
                                )
                                pool.call(self.controller.set_target, name,
                                          value)

    @elasticapm.capture_span(span_type=APM_SPAN_TYPE)
    def handle_service_error(self, service_name: str):
        """Count an error against an *analysis* service, disabling it if needed.

        Errors in core systems should simply be logged while a best effort is
        made to keep going. For analysis services the error is remembered for
        ``ERROR_EXPIRY_TIME`` seconds; once ``MAXIMUM_SERVICE_ERRORS`` or more
        are on record the service is disabled in the datastore and its error
        history is cleared.
        """
        with self.error_count_lock:
            # Record this error, creating the history on first use
            timestamps = self.error_count.setdefault(service_name, [])
            timestamps.append(time.time())

            # Forget errors that are older than the expiry window
            recent = [
                stamp for stamp in timestamps
                if stamp >= time.time() - ERROR_EXPIRY_TIME
            ]
            self.error_count[service_name] = recent

            if len(recent) < MAXIMUM_SERVICE_ERRORS:
                return

            self.log.warning(
                f"Scaler has encountered too many errors trying to load {service_name}. "
                "The service will be permanently disabled...")
            disabled = self.datastore.service_delta.update(service_name, [
                (self.datastore.service_delta.UPDATE_SET, 'enabled', False)
            ])
            if disabled:
                # Raise awareness to other components by sending an event for the service
                self.service_event_sender.send(service_name, {
                    'operation': Operation.Modified,
                    'name': service_name
                })
            del self.error_count[service_name]

    def sync_metrics(self):
        """Feed service heartbeats and queue lengths into the service profiles.

        Runs for as long as ``self.sleep(METRIC_SYNC_INTERVAL)`` returns a
        truthy value. Each pass reads the status table into ``self.state``,
        drops entries that expired more than 600 seconds ago, and updates
        every profile from the regularized metrics and its queue length.
        """
        while self.sleep(METRIC_SYNC_INTERVAL):
            with apm_span(self.apm_client, 'sync_metrics'):
                # Pull service metrics from redis
                heartbeats = self.status_table.items()
                for host, (service, state, time_limit) in heartbeats.items():
                    # If an entry hasn't expired, take it into account
                    if time.time() < time_limit:
                        if state == ServiceStatus.Running:
                            busy = METRIC_SYNC_INTERVAL
                        else:
                            busy = 0
                        self.state.update(service=service,
                                          host=host,
                                          throughput=0,
                                          busy_seconds=busy)

                    # If an entry expired a while ago, the host is probably not in use any more
                    if time.time() > time_limit + 600:
                        self.status_table.pop(host)

                # Snapshot each profile's target instance count under the lock
                with self.profiles_lock:
                    targets = {}
                    for name, profile in self.profiles.items():
                        targets[name] = profile.target_instances

                # Check the set of services that might be sitting at zero instances, and if it is, we need to
                # manually check if it is offline
                export_interval = self.config.core.metrics.export_interval

                with self.profiles_lock:
                    # Fetch the length of every profile queue in one call
                    queues = [
                        candidate.queue
                        for candidate in self.profiles.values()
                        if candidate.queue
                    ]
                    lengths = dict(zip(queues, pq_length(*queues)))

                    for profile_name, profile in self.profiles.items():
                        queue_length = lengths.get(profile.queue, 0)

                        # Pull out statistics from the metrics regularization
                        update = self.state.read(profile_name)
                        if update:
                            profile.update(
                                delta=time.time() - profile.last_update,
                                backlog=queue_length,
                                **update)

                        # Check if we expect no messages, if so pull the queue length ourselves
                        # since there is no heartbeat
                        no_heartbeat_expected = (
                            targets.get(profile_name) == 0
                            and profile.desired_instances == 0
                            and profile.queue)
                        if not no_heartbeat_expected:
                            continue

                        if queue_length > 0:
                            self.log.info(
                                f"Service at zero instances has messages: "
                                f"{profile.name} ({queue_length} in queue)"
                            )
                        profile.update(delta=export_interval,
                                       instances=0,
                                       backlog=queue_length,
                                       duty_cycle=profile.high_duty_cycle)

    def _timeout_kill(self, service, container):
        """Stop ``container`` of ``service`` and drop its status table entry."""
        span = apm_span(self.apm_client, 'timeout_kill')
        with span:
            # Stop the container first; its status entry is only removed
            # once that call has returned.
            self.controller.stop_container(service, container)
            self.status_table.pop(container)

    def process_timeouts(self):
        """Kill service containers named by messages on the timeout queue.

        While running, messages are popped from ``scaler_timeout_queue`` (with
        a one second timeout) and each kill is submitted to a ten-worker
        thread pool. Completed kills are only inspected after a new message
        has been submitted; any exception they raised is logged.
        """
        with concurrent.futures.ThreadPoolExecutor(10) as pool:
            pending = []

            while self.running:
                message = self.scaler_timeout_queue.pop(blocking=True,
                                                        timeout=1)
                if not message:
                    continue

                with apm_span(self.apm_client, 'process_timeouts'):
                    # Process new messages
                    container = message['container']
                    service = message['service']
                    self.log.info(
                        f"Killing service container: {container} running: {service}"
                    )
                    pending.append(
                        pool.submit(self._timeout_kill, service, container))

                    # Process finished: report failures, keep what is still running
                    still_running = []
                    for job in pending:
                        if not job.done():
                            still_running.append(job)
                            continue
                        failure = job.exception()
                        if failure is not None:
                            self.log.error(
                                f"Exception trying to stop timed out service container: {failure}"
                            )
                    pending = still_running

    def export_metrics(self):
        """Export per-service scaling status and cluster resource metrics.

        Runs for as long as ``self.sleep(self.config.logging.export_interval)``
        returns a truthy value. Each pass exports one ``scaler_status`` record
        per profile followed by a single ``scaler`` record carrying the
        controller's memory and CPU figures.
        """
        while self.sleep(self.config.logging.export_interval):
            with apm_span(self.apm_client, 'export_metrics'):
                # Snapshot the figures under the lock, export after releasing it
                with self.profiles_lock:
                    service_metrics = {
                        service_name: {
                            'running': profile.running_instances,
                            'target': profile.target_instances,
                            'minimum': profile.min_instances,
                            'maximum': profile.instance_limit,
                            'dynamic_maximum': profile.max_instances,
                            'queue': profile.queue_length,
                            'duty_cycle': profile.duty_cycle,
                            'pressure': profile.pressure
                        }
                        for service_name, profile in self.profiles.items()
                    }

                for service_name, status in service_metrics.items():
                    export_metrics_once(service_name,
                                        Status,
                                        status,
                                        host=HOSTNAME,
                                        counter_type='scaler_status',
                                        config=self.config,
                                        redis=self.redis)

                # Cluster-wide figures as reported by the controller
                memory, memory_total = self.controller.memory_info()
                cpu, cpu_total = self.controller.cpu_info()
                cluster_metrics = dict(memory_total=memory_total,
                                       cpu_total=cpu_total,
                                       memory_free=memory,
                                       cpu_free=cpu)
                export_metrics_once('scaler',
                                    Metrics,
                                    cluster_metrics,
                                    host=HOSTNAME,
                                    counter_type='scaler',
                                    config=self.config,
                                    redis=self.redis)

    def log_container_events(self):
        """Log every new event reported by the controller at warning level.

        Runs for as long as ``self.sleep(CONTAINER_EVENTS_LOG_INTERVAL)``
        returns a truthy value.
        """
        while self.sleep(CONTAINER_EVENTS_LOG_INTERVAL):
            with apm_span(self.apm_client, 'log_container_events'):
                events = self.controller.new_events()
                for event in events:
                    self.log.warning("Container Event :: " + event)
class ScalerServer(CoreBase):
    def __init__(self,
                 config=None,
                 datastore=None,
                 redis=None,
                 redis_persist=None):
        """Build the redis structures, the container controller and the scheduler.

        All four arguments are handed to the base class constructor unchanged.
        A KubernetesController is created when KUBERNETES_AL_CONFIG is set,
        a DockerController otherwise.
        """
        super().__init__('assemblyline.scaler',
                         config=config,
                         datastore=datastore,
                         redis=redis,
                         redis_persist=redis_persist)

        # Redis backed structures, plus the in-memory error counter keyed by service name
        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE, host=self.redis_persist)
        self.error_count = {}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH, host=self.redis, ttl=30 * 60)

        # Labels attached to everything the controller launches
        service_labels = {'app': 'assemblyline', 'section': 'service'}

        if KUBERNETES_AL_CONFIG:
            self.log.info(f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}")
            self.controller = KubernetesController(logger=self.log,
                                                   prefix='alsvc_',
                                                   labels=service_labels,
                                                   namespace=NAMESPACE,
                                                   priority='al-service-priority')
            # When the config map is known, mount the classification into the service containers
            if CLASSIFICATION_CONFIGMAP:
                self.controller.config_mount('classification-config',
                                             config_map=CLASSIFICATION_CONFIGMAP,
                                             key=CLASSIFICATION_CONFIGMAP_KEY,
                                             target_path='/etc/assemblyline/classification.yml')
        else:
            self.log.info("Loading Docker cluster interface.")
            scaler_settings = self.config.core.scaler
            self.controller = DockerController(logger=self.log,
                                               prefix=NAMESPACE,
                                               cpu_overallocation=scaler_settings.cpu_overallocation,
                                               memory_overallocation=scaler_settings.memory_overallocation,
                                               labels=service_labels)
            # When the host path is known, mount the classification into the service containers
            if CLASSIFICATION_HOST_PATH:
                classification_mount = (CLASSIFICATION_HOST_PATH, '/etc/assemblyline/classification.yml')
                self.controller.global_mounts.append(classification_mount)

        # Profiles of the services currently being scaled, keyed by service name
        self.profiles: Dict[str, ServiceProfile] = {}

        # Metrics state and the single threaded scheduler that drives every periodic task
        self.state = collection.Collection(period=self.config.core.metrics.export_interval)
        self.scheduler = sched.scheduler()
        self.scheduler_stopped = threading.Event()

    def add_service(self, profile: ServiceProfile):
        """Start tracking a service profile and hand it to the controller.

        The initial desired instance count is the controller's current target
        for the service, raised to the profile's minimum when that is larger.
        """
        current_target = self.controller.get_target(profile.name)
        profile.desired_instances = max(current_target, profile.min_instances)
        profile.running_instances = profile.desired_instances
        self.log.debug(f'Starting service {profile.name} with a target of {profile.desired_instances}')
        profile.last_update = time.time()

        # Record the profile locally before registering it with the controller
        self.profiles[profile.name] = profile
        self.controller.add_profile(profile)

    def try_run(self):
        """Prime every periodic task, then drive the scheduler until the server stops."""
        # One direct call to each task; the tasks queue their own follow-up runs
        # on the scheduler from inside their bodies.
        startup_tasks = (
            self.sync_services,
            self.sync_metrics,
            self.update_scaling,
            self.expire_errors,
            self.process_timeouts,
            self.export_metrics,
            self.flush_service_status,
            self.log_container_events,
            self.heartbeat,
        )
        for task in startup_tasks:
            task()

        # A non-blocking run executes whatever is due and reports the wait until
        # the next event; sleep for that long, capped at two seconds.
        while self.running:
            next_due = self.scheduler.run(blocking=False)
            time.sleep(min(next_due, 2))

        # Let stop() know the loop has exited
        self.scheduler_stopped.set()

    def stop(self):
        """Stop the server, wait up to five seconds for the scheduler loop, then stop the controller."""
        super().stop()
        # try_run sets this event once its loop has finished
        self.scheduler_stopped.wait(timeout=5)
        self.controller.stop()

    def heartbeat(self):
        """Periodically touch a file on disk.

        Nothing happens, and nothing is rescheduled, unless a heartbeat file is
        configured. Because tasks run serially on the scheduler, the gap between
        two touches is the larger of HEARTBEAT_INTERVAL and the longest running task.
        """
        if not self.config.logging.heartbeat_file:
            return

        # Queue the next beat before performing this one
        self.scheduler.enter(HEARTBEAT_INTERVAL, 0, self.heartbeat)
        super().heartbeat()

    def sync_services(self):
        """Reconcile the scaled profiles with the services listed in the datastore.

        For every listed service:

        * a disabled service is stopped;
        * an enabled service in the ``Off`` stage gets its network prepared and
          its dependency containers started, and is then moved to the ``Update``
          stage (when its update config asks to wait for an update) or to
          ``Running``;
        * an enabled service in the ``Running`` stage is either added to scaling
          or has its existing profile refreshed, restarting it when the container
          configuration or the configuration hash changed.

        Any exception raised while handling one service is logged and counted
        through ``handle_service_error``. Profiles whose service is no longer
        listed are stopped at the end.
        """
        self.scheduler.enter(SERVICE_SYNC_INTERVAL, 0, self.sync_services)
        default_settings = self.config.core.scaler.service_defaults
        image_variables = defaultdict(str)
        image_variables.update(self.config.services.image_variables)
        previously_scaled = set(self.profiles.keys())
        seen_in_datastore = set()

        for service in self.datastore.list_all_services(full=True):
            name = service.name
            stage = self.get_service_stage(service.name)
            seen_in_datastore.add(name)

            # noinspection PyBroadException
            try:
                # Disabled services only ever need to be shut down
                if not service.enabled:
                    self.stop_service(service.name, stage)
                    continue

                if stage == ServiceStage.Off:
                    # Bring up the network and the dependency containers of this service
                    self.controller.prepare_network(service.name,
                                                    service.docker_config.allow_internet_access)
                    for dependency_name, dependency in service.dependencies.items():
                        self.controller.start_stateful_container(
                            service_name=service.name,
                            container_name=dependency_name,
                            spec=dependency,
                            labels={'dependency_for': service.name})

                    # Advance the shared stage record; scaling itself starts on a later pass
                    waits_for_update = service.update_config and service.update_config.wait_for_update
                    next_stage = ServiceStage.Update if waits_for_update else ServiceStage.Running
                    self._service_stage_hash.set(name, next_stage)
                    continue

                if stage != ServiceStage.Running:
                    continue

                # Hash the service properties that live outside the docker config but
                # should still cause a restart when they change
                config_hash = hash(str(sorted(service.config.items())))
                config_hash = hash((config_hash, str(service.submission_params)))

                # Build the docker config; it is needed whether the profile is created or updated
                docker_config = service.docker_config
                docker_config.image = Template(docker_config.image).safe_substitute(image_variables)
                already_set = {var.name for var in docker_config.environment}
                for var in default_settings.environment:
                    if var.name in already_set:
                        continue
                    docker_config.environment.append(var)

                if name not in self.profiles:
                    # First sighting of this service: start scaling it
                    self.log.info(f'Adding {service.name} to scaling')
                    new_profile = ServiceProfile(
                        name=name,
                        min_instances=default_settings.min_instances,
                        growth=default_settings.growth,
                        shrink=default_settings.shrink,
                        config_hash=config_hash,
                        backlog=default_settings.backlog,
                        max_instances=service.licence_count,
                        container_config=docker_config,
                        queue=get_service_queue(name, self.redis),
                        # Give service an extra 30 seconds to upload results
                        shutdown_seconds=service.timeout + 30,
                    )
                    self.add_service(new_profile)
                    continue

                # Known service: update RAM, CPU and licence requirements
                profile = self.profiles[name]
                if profile.container_config != docker_config or profile.config_hash != config_hash:
                    self.log.info(f"Updating deployment information for {name}")
                    profile.container_config = docker_config
                    profile.config_hash = config_hash
                    self.controller.restart(profile)
                    self.log.info(f"Deployment information for {name} replaced")

                # A licence count of zero is treated as "no limit"
                profile._max_instances = float('inf') if service.licence_count == 0 else service.licence_count
            except Exception:
                self.log.exception(f"Error applying service settings from: {service.name}")
                self.handle_service_error(service.name)

        # Anything still scaled here but missing from the datastore gets removed
        for stray_service in previously_scaled - seen_in_datastore:
            stray_stage = self.get_service_stage(stray_service)
            self.stop_service(stray_service, stray_stage)

    def stop_service(self, name, current_stage):
        """Shut a service down: its dependency containers, its stage record and its scaling.

        :param name: name of the service to stop
        :param current_stage: stage the caller last read for the service
        """
        if current_stage != ServiceStage.Off:
            # Take down the dependencies, then flag the service as off in the shared record
            self.controller.stop_containers(labels={'dependency_for': name})
            self._service_stage_hash.set(name, ServiceStage.Off)

        # Nothing more to do unless the service is still tracked or still has a target
        still_scaled = name in self.profiles or self.controller.get_target(name) > 0
        if not still_scaled:
            return

        self.log.info(f'Removing {name} from scaling')
        self.controller.set_target(name, 0)
        self.profiles.pop(name, None)

    def update_scaling(self):
        """Check if we need to scale any services up or down.

        Works in three passes over the known profiles:

        1. Services whose controller target exceeds their desired instance
           count are scaled down right away.
        2. Services below their minimum instance count are raised to it.
        3. The controller's free CPU and memory are handed out one instance at
           a time, each time to the service with the fewest instances allocated
           so far, until nothing that still wants to grow fits.

        The resulting targets are then applied to the controller. A
        ServiceControlError raised along the way is logged and counted against
        the service named by the error.
        """
        self.scheduler.enter(SCALE_INTERVAL, 0, self.update_scaling)
        try:
            # Figure out what services are expected to be running and how many
            profiles: List[ServiceProfile] = list(self.profiles.values())
            targets = {
                _p.name: self.controller.get_target(_p.name)
                for _p in profiles
            }

            for name, profile in self.profiles.items():
                self.log.debug(f'{name}')
                self.log.debug(
                    f'Instances \t{profile.min_instances} < {profile.desired_instances} | '
                    f'{targets[name]} < {profile.max_instances}')
                self.log.debug(
                    f'Pressure \t{profile.shrink_threshold} < {profile.pressure} < {profile.growth_threshold}'
                )

            #
            #   1.  Any processes that want to release resources can always be approved first
            #
            for name, profile in self.profiles.items():
                if targets[name] > profile.desired_instances:
                    self.log.info(
                        f"{name} wants less resources changing allocation "
                        f"{targets[name]} -> {profile.desired_instances}")
                    self.controller.set_target(name, profile.desired_instances)
                    targets[name] = profile.desired_instances
                if not self.running:
                    return

            #
            #   2.  Any processes that aren't reaching their min_instances target must be given
            #       more resources before anyone else is considered.
            #
            for name, profile in self.profiles.items():
                if targets[name] < profile.min_instances:
                    self.log.info(
                        f"{name} isn't meeting minimum allocation "
                        f"{targets[name]} -> {profile.min_instances}")
                    self.controller.set_target(name, profile.min_instances)
                    targets[name] = profile.min_instances

            #
            #   3.  Try to estimate available resources, and based on some metric grant the
            #       resources to each service that wants them. While this free memory
            #       pool might be spread across many nodes, we are going to treat it like
            #       it is one big one, and let the orchestration layer sort out the details.
            #
            free_cpu = self.controller.free_cpu()
            free_memory = self.controller.free_memory()

            def trim(prof: List[ServiceProfile]):
                """Keep only the profiles that still want to grow and fit in the free resources."""
                wanting = [
                    _p for _p in prof
                    if _p.desired_instances > targets[_p.name]
                ]
                fitting = []
                drop = {}
                for _p in wanting:
                    if _p.cpu <= free_cpu and _p.ram <= free_memory:
                        fitting.append(_p)
                    else:
                        drop[_p.name] = (_p.cpu, _p.ram)
                if drop:
                    self.log.debug(
                        f"Can't make more because not enough resources {drop}")
                return fitting

            profiles = trim(profiles)

            while profiles:
                # Grant one more instance to the service with the fewest allocated so far.
                # The in-progress ``targets`` are the sort key: the controller's own targets
                # are not changed until the end of this method, so sorting on them would
                # keep the same service in front and give it every grant.
                # TODO do we need to add balancing metrics other than 'least running' for this? probably
                profiles.sort(key=lambda _p: targets[_p.name])
                neediest = profiles[0]

                free_memory -= neediest.container_config.ram_mb
                free_cpu -= neediest.container_config.cpu_cores
                targets[neediest.name] += 1

                # Drop whoever is now satisfied or no longer fits in what is left
                profiles = trim(profiles)

            # Apply those adjustments we have made back to the controller
            for name, value in targets.items():
                old = self.controller.get_target(name)
                if value != old:
                    self.log.info(f"Scaling service {name}: {old} -> {value}")
                    self.controller.set_target(name, value)
                if not self.running:
                    return

        except ServiceControlError as error:
            self.log.exception("Error while scaling services.")
            self.handle_service_error(error.service_name)

    def handle_service_error(self, service_name):
        """Handle an error occurring in the *analysis* service.

        Errors for core systems should simply be logged, and a best effort to continue made.

        For analysis services the error is only counted until the count reaches
        MAXIMUM_SERVICE_ERRORS; at that point the service is disabled in the
        datastore and its counter is discarded.
        """
        failures = self.error_count.get(service_name, 0) + 1
        self.error_count[service_name] = failures

        # Still within tolerance: nothing else to do
        if failures < MAXIMUM_SERVICE_ERRORS:
            return

        service_delta = self.datastore.service_delta
        service_delta.update(service_name, [(service_delta.UPDATE_SET, 'enabled', False)])
        del self.error_count[service_name]

    def sync_metrics(self):
        """Fold the service status table into the metrics state and refresh every profile."""
        self.scheduler.enter(METRIC_SYNC_INTERVAL, 3, self.sync_metrics)

        # Walk the status entries stored in redis, one per host
        for host, (service, state, time_limit) in self.status_table.items().items():
            # Entries that have not expired yet count toward the service's metrics
            if time.time() < time_limit:
                busy_seconds = METRIC_SYNC_INTERVAL if state == ServiceStatus.Running else 0
                self.state.update(service=service,
                                  host=host,
                                  throughput=0,
                                  busy_seconds=busy_seconds)

            # Entries that expired over ten minutes ago probably belong to hosts no longer in use
            if time.time() > time_limit + 600:
                self.status_table.pop(host)

        export_interval = self.config.core.metrics.export_interval
        for profile_name, profile in self.profiles.items():
            # Feed the regularized statistics, when there are any, into the profile
            update = self.state.read(profile_name)
            if update:
                delta = time.time() - profile.last_update
                profile.update(delta=delta,
                               backlog=profile.queue.length(),
                               **update)

            # A service sitting at zero instances sends no heartbeat, so its queue
            # length has to be pulled here directly
            sitting_at_zero = (self.controller.get_target(profile_name) == 0
                               and profile.desired_instances == 0
                               and profile.queue)
            if not sitting_at_zero:
                continue

            queue_length = profile.queue.length()
            if queue_length > 0:
                self.log.info(f"Service at zero instances has messages: "
                              f"{profile.name} ({queue_length} in queue)")
            profile.update(delta=export_interval,
                           instances=0,
                           backlog=queue_length,
                           duty_cycle=profile.target_duty_cycle)

        # TODO maybe find a less aggressive way of handling services that should have
        #      instances running but have sent no heartbeat for longer than their
        #      shutdown_seconds (a container that cannot start properly); the earlier
        #      approach was to disable such a service in the datastore.

    def expire_errors(self):
        """Forgive one recorded error per service, dropping services that reach zero."""
        self.scheduler.enter(ERROR_EXPIRY_INTERVAL, 0, self.expire_errors)

        remaining = {}
        for name, count in self.error_count.items():
            # A count of one (or less) would decay to nothing, so the entry is not kept
            if count > 1:
                remaining[name] = count - 1
        self.error_count = remaining

    def process_timeouts(self):
        """Drain the scaler timeout queue, stopping the container named in each message."""
        self.scheduler.enter(PROCESS_TIMEOUT_INTERVAL, 0, self.process_timeouts)

        message = self.scaler_timeout_queue.pop(blocking=False)
        while message:
            # noinspection PyBroadException
            try:
                self.log.info(
                    f"Killing service container: {message['container']} running: {message['service']}"
                )
                self.controller.stop_container(message['service'], message['container'])
            except Exception:
                # One bad message must not keep the rest of the queue from being processed
                self.log.exception(
                    f"Exception trying to stop timed out service container: {message}"
                )
            message = self.scaler_timeout_queue.pop(blocking=False)

    def export_metrics(self):
        """Export one status record per scaled service, then the controller's resource figures."""
        self.scheduler.enter(self.config.logging.export_interval, 0, self.export_metrics)

        # Per-service scaling status
        for service_name, profile in self.profiles.items():
            service_status = {
                'running': profile.running_instances,
                'target': profile.desired_instances,
                'minimum': profile.min_instances,
                'maximum': profile.instance_limit,
                'dynamic_maximum': profile.max_instances,
                'queue': profile.queue_length,
                'duty_cycle': profile.duty_cycle,
                'pressure': profile.pressure
            }
            export_metrics_once(service_name, Status, service_status,
                                host=HOSTNAME, counter_type='scaler-status',
                                config=self.config, redis=self.redis)

        # Free and total memory / CPU as reported by the controller
        memory_free, memory_total = self.controller.memory_info()
        cpu_free, cpu_total = self.controller.cpu_info()
        resource_figures = {
            'memory_total': memory_total,
            'cpu_total': cpu_total,
            'memory_free': memory_free,
            'cpu_free': cpu_free
        }
        export_metrics_once('scaler', Metrics, resource_figures,
                            host=HOSTNAME, counter_type='scaler',
                            config=self.config, redis=self.redis)

    def flush_service_status(self):
        """Remove status table entries whose container is no longer reported as running.

        The service status table may still reference containers that have crashed.
        """
        self.scheduler.enter(SERVICE_STATUS_FLUSH, 0, self.flush_service_status)

        # Names of every container the controller currently reports as running
        running_containers = set(self.controller.get_running_container_names())

        # Drop each status entry that does not match one of them
        for hostname in self.status_table.keys():
            if hostname in running_containers:
                continue
            self.status_table.pop(hostname)

    def log_container_events(self):
        """Relay new container events reported by the controller to the log as warnings."""
        self.scheduler.enter(CONTAINER_EVENTS_LOG_INTERVAL, 0, self.log_container_events)

        for event in self.controller.new_events():
            self.log.warning("Container Event :: " + event)
Esempio n. 21
0
    def service_finished(self,
                         sid: str,
                         result_key: str,
                         result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch.

        The running task is removed from the running-task table, the result is
        saved (or its expiry refreshed), the service is marked finished in the
        submission's dispatch table, tags and temporary data are pushed to redis,
        extracted files are queued for dispatching (or recorded as errors when
        the file or depth limit is hit) and watchers are told about the result.

        Returns early, after logging a warning, when the task is not found in
        the running-task table or when the dispatch table reports the task as
        already completed.

        :param sid: ID of the submission the task belongs to
        :param result_key: key the result is stored under
        :param result: result produced by the service
        :param temporary_data: optional key/value pairs to store in the temporary
                               data table of the processed file
        """
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=result.response.service_name,
            sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing successful results."
            )
            return
        task = ServiceTask(task)

        # Check if the service is a candidate for dynamic recursion prevention
        if not task.ignore_dynamic_recursion_prevention:
            service_info = self.service_data.get(result.response.service_name,
                                                 None)
            if service_info and service_info.category == "Dynamic Analysis":
                # TODO: This should be done in lua because it can introduce race condition in the future
                #       but in the meantime it will remain this way while we can confirm it work as expected
                # NOTE(review): nothing here guards against the submission being absent from
                #               active_submissions; presumably get() returns None in that case
                #               and the next statement would raise - confirm.
                submission = self.active_submissions.get(sid)
                submission['submission']['params']['services'][
                    'runtime_excluded'].append(result.response.service_name)
                self.active_submissions.set(sid, submission)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key,
                                     {"expiry_ts": result.archive_ts})
        else:
            # The read-modify-write of the stored result is done under a lock named after the result key
            with Lock(f"lock-{result_key}", 5, self.redis):
                old = self.ds.result.get(result_key)
                if old:
                    # Keep the later expiry when both have one; otherwise clear the expiry
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        result.expiry_ts = None
                self.ds.result.save(result_key, result)

        # Let the logs know we have received a result for this task
        if result.drop_file:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key} but processing will stop after this service."
            )
        else:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key}")

        # Store the result object and mark the service finished in the global table
        process_table = DispatchHash(task.sid, self.redis)
        remaining, duplicate = process_table.finish(
            task.fileinfo.sha256, task.service_name, result_key,
            result.result.score, result.classification, result.drop_file)
        # The task is done, so the timeout registered for it is cleared
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')
        if duplicate:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already "
                f"completed in the global processing table.")
            return

        # Push the result tags into redis
        new_tags = []
        for section in result.result.sections:
            new_tags.extend(tag_dict_to_list(section.tags.as_primitives()))
        if new_tags:
            tag_set = ExpiringSet(get_tag_set_name(
                sid=task.sid, file_hash=task.fileinfo.sha256),
                                  host=self.redis)
            tag_set.add(*new_tags)

        # Update the temporary data table for this file
        temp_data_hash = ExpiringHash(get_temporary_submission_data_name(
            sid=task.sid, file_hash=task.fileinfo.sha256),
                                      host=self.redis)
        for key, value in (temporary_data or {}).items():
            temp_data_hash.set(key, value)

        # Send the extracted files to the dispatcher
        depth_limit = self.config.submission.max_extraction_depth
        new_depth = task.depth + 1
        if new_depth < depth_limit:
            # Prepare the temporary data from the parent to build the temporary data table for
            # these newly extracted files
            parent_data = dict(temp_data_hash.items())

            for extracted_data in result.response.extracted:
                # A falsy return from add_file is handled as "too many files": the file is
                # reported as an error and not dispatched
                if not process_table.add_file(
                        extracted_data.sha256,
                        task.max_files,
                        parent_hash=task.fileinfo.sha256):
                    # NOTE(review): the parent's temporary data is copied to the child only on
                    #               this rejection path, not for the files accepted below -
                    #               confirm this is intended.
                    if parent_data:
                        child_hash_name = get_temporary_submission_data_name(
                            task.sid, extracted_data.sha256)
                        ExpiringHash(child_hash_name,
                                     host=self.redis).multi_set(parent_data)

                    self._dispatching_error(
                        task, process_table,
                        Error({
                            'archive_ts': result.archive_ts,
                            'expiry_ts': result.expiry_ts,
                            'response': {
                                'message':
                                f"Too many files extracted for submission {task.sid} "
                                f"{extracted_data.sha256} extracted by "
                                f"{task.service_name} will be dropped",
                                'service_name':
                                task.service_name,
                                'service_tool_version':
                                result.response.service_tool_version,
                                'service_version':
                                result.response.service_version,
                                'status':
                                'FAIL_NONRECOVERABLE'
                            },
                            'sha256': extracted_data.sha256,
                            'type': 'MAX FILES REACHED'
                        }))
                    continue
                # NOTE(review): file_data is used without a check; presumably the extracted
                #               file is always in the datastore by this point - confirm.
                file_data = self.files.get(extracted_data.sha256)
                # Queue the extracted file for dispatching, one level deeper than its parent
                self.file_queue.push(
                    FileTask(
                        dict(sid=task.sid,
                             min_classification=task.min_classification.max(
                                 extracted_data.classification).value,
                             file_info=dict(
                                 magic=file_data.magic,
                                 md5=file_data.md5,
                                 mime=file_data.mime,
                                 sha1=file_data.sha1,
                                 sha256=file_data.sha256,
                                 size=file_data.size,
                                 type=file_data.type,
                             ),
                             depth=new_depth,
                             parent_hash=task.fileinfo.sha256,
                             max_files=task.max_files)).as_primitives())
        else:
            # Depth limit reached: every extracted file is recorded as an error instead of dispatched
            for extracted_data in result.response.extracted:
                self._dispatching_error(
                    task, process_table,
                    Error({
                        'archive_ts': result.archive_ts,
                        'expiry_ts': result.expiry_ts,
                        'response': {
                            'message':
                            f"{task.service_name} has extracted a file "
                            f"{extracted_data.sha256} beyond the depth limits",
                            'service_name':
                            result.response.service_name,
                            'service_tool_version':
                            result.response.service_tool_version,
                            'service_version':
                            result.response.service_version,
                            'status':
                            'FAIL_NONRECOVERABLE'
                        },
                        'sha256': extracted_data.sha256,
                        'type': 'MAX DEPTH REACHED'
                    }))

        # If the global table said that this was the last outstanding service,
        # send a message to the dispatchers.
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w).push(msg)
Esempio n. 22
0
class DispatchClient:
    def __init__(self,
                 datastore=None,
                 redis=None,
                 redis_persist=None,
                 logger=None):
        """Connect the client to redis, the datastore and the dispatch queues.

        Every argument is optional. Anything left falsy is built from the
        configuration returned by ``forge.get_config()``.

        :param datastore: Datastore to use; defaults to
                          ``forge.get_datastore(self.config)``.
        :param redis: Client for the non-persistent redis instance.
        :param redis_persist: Client for the persistent redis instance.
        :param logger: Logger to use; defaults to the
                       "assemblyline.dispatching.client" logger.
        """
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.timeout_watcher = WatcherClient(redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger(
            "assemblyline.dispatching.client")
        # Read the collections from self.ds, not from the raw argument:
        # ``datastore`` is None whenever the caller relies on the default.
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH,
                                               host=redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH,
                                          host=self.redis)
        self.service_data = cast(Dict[str, Service],
                                 CachedObject(self._get_services))

    def _get_services(self):
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def dispatch_submission(self,
                            submission: Submission,
                            completed_queue: str = None):
        """Queue a submission for the dispatching system.

        Note:
            Most callers should go through the SubmissionTool instead.

        Prerequisites:
            - the submission is already saved in the datastore
            - its files are already in the datastore and filestore

        :param submission: Submission to hand over to the dispatcher.
        :param completed_queue: Stored on the task as ``completed_queue``.
        """
        task = SubmissionTask({
            'submission': submission,
            'completed_queue': completed_queue,
        })
        self.submission_queue.push(task.as_primitives())

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        Count, per service, the files of a submission still waiting on a result.

        :param sid: Submission ID
        :return: Mapping of service name to the number of files that service
                 still has to process, e.g. {"SERVICE_NAME": 1, ... }
        """
        # Pull the whole status table for this submission out of redis
        table = DispatchHash(sid, self.redis)
        results_by_file = table.all_results()
        file_hashes = table.all_files()
        self.log.info(
            f"[{sid}] Listing outstanding services {len(file_hashes)} files "
            f"and {len(results_by_file)} entries found")

        output: Dict[str, int] = {}

        for file_hash in file_hashes:
            # A freshly issued submission or file may have no cached schedule yet;
            # one shows up once the file has passed through the dispatcher
            schedule = table.schedules.get(file_hash)
            file_results = results_by_file.get(file_hash, {})

            # Consume the schedule one stage at a time so that a drop recorded
            # in one stage cancels every later stage
            while schedule:
                for service_name in schedule.pop(0):
                    status = file_results.get(service_name)
                    if not status:
                        # No result yet: the file is still outstanding here
                        output[service_name] = output.get(service_name, 0) + 1
                    elif status.drop:
                        schedule.clear()

        return output

    def request_work(self,
                     worker_id,
                     service_name,
                     service_version,
                     timeout: float = 60,
                     blocking=True) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param service_version:
        :param worker_id:
        :param service_name: Which service needs work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately
        :return: The job found, and a boolean value indicating if this is the first time this task has
                 been returned by request_work.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            try:
                return self._request_work(worker_id,
                                          service_name,
                                          service_version,
                                          blocking=blocking,
                                          timeout=remaining)
            except RetryRequestWork:
                remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version, timeout,
                      blocking) -> Optional[ServiceTask]:
        """Make one attempt at claiming a task from the service's queue.

        :return: The claimed task, or None when the time budget is spent or
                 the queue produced nothing.
        :raises RetryRequestWork: when the dequeued task had to be discarded
                                  and the caller should try again.
        """
        # Nothing left of the time budget (reached when the caller retries)
        if int(timeout) <= 0:
            self.log.info(
                f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Pull a single message off this service's queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            message = work_queue.blocking_pop(timeout=int(timeout))
        else:
            popped = work_queue.pop(1)
            message = popped[0] if popped else popped

        if not message:
            self.log.info(
                f"{service_name}:{worker_id} no task returned: [empty message]"
            )
            return None
        task = ServiceTask(message)

        # If someone is supposed to be working on this task right now, the
        # add is refused and the caller has to fetch something else
        if not self.running_tasks.add(task.key(), task.as_primitives()):
            raise RetryRequestWork()

        sha256 = task.fileinfo.sha256
        self.log.info(
            f"[{task.sid}/{sha256}] {service_name}:{worker_id} task found")

        process_table = DispatchHash(task.sid, self.redis)

        abandoned = process_table.dispatch_time(
            file_hash=sha256, service=task.service_name) == 0
        finished = process_table.finished(
            file_hash=sha256, service=task.service_name) is not None

        # A service can be re-dispatched just as it finishes and end up marked
        # both finished and dispatched; drop the stale dispatch in that case
        if finished and not abandoned:
            process_table.drop_dispatch(file_hash=sha256,
                                        service=task.service_name)

        if abandoned or finished:
            self.log.info(
                f"[{task.sid}/{sha256}] {service_name}:{worker_id} task already complete"
            )
            self.running_tasks.pop(task.key())
            raise RetryRequestWork()

        # Enforce the retry limit for this task
        attempt_record = ExpiringHash(f'dispatch-hash-attempts-{task.sid}',
                                      host=self.redis)
        total_attempts = attempt_record.increment(task.key())
        self.log.info(f"[{task.sid}/{sha256}] {service_name}:{worker_id} "
                      f"task attempt {total_attempts}/3")
        if total_attempts > 3:
            self.log.warning(
                f"[{task.sid}/{sha256}] "
                f"{service_name}:{worker_id} marking task failed: TASK PREEMPTED "
            )
            error = Error({
                'archive_ts': now_as_iso(
                    self.config.datastore.ilm.days_until_archive * 24 * 60 *
                    60),
                'created': 'NOW',
                'expiry_ts': now_as_iso(task.ttl * 24 * 60 * 60)
                if task.ttl else None,
                'response': {
                    'message': 'The number of retries has passed the limit.',
                    'service_name': task.service_name,
                    'service_version': service_version,
                    'status': 'FAIL_NONRECOVERABLE',
                },
                'sha256': sha256,
                'type': "TASK PRE-EMPTED",
            })
            error_key = error.build_key(task=task)
            self.service_failed(task.sid, error_key, error)
            export_metrics_once(service_name,
                                Metrics,
                                dict(fail_nonrecoverable=1),
                                host=worker_id,
                                counter_type='service')
            raise RetryRequestWork()

        # Register the task with the timeout watcher using the service's timeout
        service_data = self.service_data[task.service_name]
        self.timeout_watcher.touch_task(timeout=int(service_data.timeout),
                                        key=f'{task.sid}-{task.key()}',
                                        worker=worker_id,
                                        task_key=task.key())
        return task

    def _dispatching_error(self, task, process_table, error):
        error_key = error.build_key(task=task)
        if process_table.add_error(error_key):
            self.errors.save(error_key, error)
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)

    def service_finished(self,
                         sid: str,
                         result_key: str,
                         result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch.

        Steps, in order:
          1. Pop the task from ``running_tasks``; if it is not there, log a
             warning and return without doing anything else.
          2. For "Dynamic Analysis" services (unless the task opts out), add
             the service to the submission's ``runtime_excluded`` list.
          3. Save the result (or an empty-result marker) in the datastore.
          4. Mark the service finished in the submission's DispatchHash and
             clear its timeout watch; return early on a duplicate finish.
          5. Publish the result's tags and the temporary data to redis.
          6. Queue each extracted file, or record an error for it when the
             file or depth limits are hit.
          7. If no services remain outstanding, re-queue the file itself.
          8. Push the result key to every watch queue of the submission.

        :param sid: Submission ID the result belongs to.
        :param result_key: Datastore key under which the result is stored.
        :param result: Result produced by the service.
        :param temporary_data: Optional key/value pairs written to the
                               temporary data table of this file.
        """
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=result.response.service_name,
            sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing successful results."
            )
            return
        task = ServiceTask(task)

        # Check if the service is a candidate for dynamic recursion prevention
        if not task.ignore_dynamic_recursion_prevention:
            service_info = self.service_data.get(result.response.service_name,
                                                 None)
            if service_info and service_info.category == "Dynamic Analysis":
                # TODO: This should be done in lua because it can introduce race condition in the future
                #       but in the meantime it will remain this way while we can confirm it work as expected
                # NOTE(review): assumes the submission is still present in
                # active_submissions; a missing entry would fail on the
                # subscript below - confirm this cannot happen.
                submission = self.active_submissions.get(sid)
                submission['submission']['params']['services'][
                    'runtime_excluded'].append(result.response.service_name)
                self.active_submissions.set(sid, submission)

        # Save or freshen the result, the CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key,
                                     {"expiry_ts": result.archive_ts})
        else:
            # The read-modify-write of the stored result is done under a lock
            # named after the result key
            with Lock(f"lock-{result_key}", 5, self.redis):
                old = self.ds.result.get(result_key)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        # Either copy has no expiry, so the saved one gets none
                        result.expiry_ts = None
                self.ds.result.save(result_key, result)

        # Let the logs know we have received a result for this task
        if result.drop_file:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key} but processing will stop after this service."
            )
        else:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key}")

        # Store the result object and mark the service finished in the global table
        process_table = DispatchHash(task.sid, self.redis)
        remaining, duplicate = process_table.finish(
            task.fileinfo.sha256, task.service_name, result_key,
            result.result.score, result.classification, result.drop_file)
        # The task is done, so its timeout watch is no longer needed
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')
        if duplicate:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already "
                f"completed in the global processing table.")
            return

        # Push the result tags into redis
        new_tags = []
        for section in result.result.sections:
            new_tags.extend(tag_dict_to_list(section.tags.as_primitives()))
        if new_tags:
            tag_set = ExpiringSet(get_tag_set_name(
                sid=task.sid, file_hash=task.fileinfo.sha256),
                                  host=self.redis)
            tag_set.add(*new_tags)

        # Update the temporary data table for this file
        temp_data_hash = ExpiringHash(get_temporary_submission_data_name(
            sid=task.sid, file_hash=task.fileinfo.sha256),
                                      host=self.redis)
        for key, value in (temporary_data or {}).items():
            temp_data_hash.set(key, value)

        # Send the extracted files to the dispatcher
        depth_limit = self.config.submission.max_extraction_depth
        new_depth = task.depth + 1
        if new_depth < depth_limit:
            # Prepare the temporary data from the parent to build the temporary data table for
            # these newly extract files
            parent_data = dict(temp_data_hash.items())

            for extracted_data in result.response.extracted:
                # A falsy add_file is reported below as 'MAX FILES REACHED'
                if not process_table.add_file(
                        extracted_data.sha256,
                        task.max_files,
                        parent_hash=task.fileinfo.sha256):
                    # NOTE(review): the parent's temporary data is copied to
                    # the child only on this rejected path - confirm intended.
                    if parent_data:
                        child_hash_name = get_temporary_submission_data_name(
                            task.sid, extracted_data.sha256)
                        ExpiringHash(child_hash_name,
                                     host=self.redis).multi_set(parent_data)

                    self._dispatching_error(
                        task, process_table,
                        Error({
                            'archive_ts': result.archive_ts,
                            'expiry_ts': result.expiry_ts,
                            'response': {
                                'message':
                                f"Too many files extracted for submission {task.sid} "
                                f"{extracted_data.sha256} extracted by "
                                f"{task.service_name} will be dropped",
                                'service_name':
                                task.service_name,
                                'service_tool_version':
                                result.response.service_tool_version,
                                'service_version':
                                result.response.service_version,
                                'status':
                                'FAIL_NONRECOVERABLE'
                            },
                            'sha256': extracted_data.sha256,
                            'type': 'MAX FILES REACHED'
                        }))
                    continue
                # The file was accepted: queue it one level deeper than its parent
                file_data = self.files.get(extracted_data.sha256)
                self.file_queue.push(
                    FileTask(
                        dict(sid=task.sid,
                             min_classification=task.min_classification.max(
                                 extracted_data.classification).value,
                             file_info=dict(
                                 magic=file_data.magic,
                                 md5=file_data.md5,
                                 mime=file_data.mime,
                                 sha1=file_data.sha1,
                                 sha256=file_data.sha256,
                                 size=file_data.size,
                                 type=file_data.type,
                             ),
                             depth=new_depth,
                             parent_hash=task.fileinfo.sha256,
                             max_files=task.max_files)).as_primitives())
        else:
            # Depth limit reached: every extracted file is reported as an error
            for extracted_data in result.response.extracted:
                self._dispatching_error(
                    task, process_table,
                    Error({
                        'archive_ts': result.archive_ts,
                        'expiry_ts': result.expiry_ts,
                        'response': {
                            'message':
                            f"{task.service_name} has extracted a file "
                            f"{extracted_data.sha256} beyond the depth limits",
                            'service_name':
                            result.response.service_name,
                            'service_tool_version':
                            result.response.service_tool_version,
                            'service_version':
                            result.response.service_version,
                            'status':
                            'FAIL_NONRECOVERABLE'
                        },
                        'sha256': extracted_data.sha256,
                        'type': 'MAX DEPTH REACHED'
                    }))

        # If the global table said that this was the last outstanding service,
        # send a message to the dispatchers.
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w).push(msg)

    def service_failed(self, sid: str, error_key: str, error: Error):
        """Notify the dispatcher that a service failed on one of its tasks.

        The task is popped from ``running_tasks``; if it is not found, a
        warning is logged and nothing else happens. A "FAIL_RECOVERABLE"
        error is only flagged in the dispatch table. Any other status is
        saved to the datastore, recorded as a non-recoverable failure and
        pushed to the submission's watch queues. The file is then re-queued
        for the dispatcher whenever ``remaining`` is not positive, which is
        always the case for recoverable errors.

        :param sid: Submission ID the error belongs to.
        :param error_key: Datastore key under which the error is stored.
        :param error: Error reported by (or on behalf of) the service.
        """
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=error.response.service_name,
            sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(
            f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error."
        )
        # -1 forces the re-issue message below for recoverable errors
        remaining = -1
        # Mark the attempt to process the file over in the dispatch table
        process_table = DispatchHash(task.sid, self.redis)
        if error.response.status == "FAIL_RECOVERABLE":
            # Because the error is recoverable, we will not save it nor we will notify the user
            process_table.fail_recoverable(task.fileinfo.sha256,
                                           task.service_name)
        else:
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            remaining, _duplicate = process_table.fail_nonrecoverable(
                task.fileinfo.sha256, task.service_name, error_key)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')

        # Send a message to prompt the re-issue of the task if needed
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    # Pass the classification's ``.value`` exactly as
                    # service_finished does when it re-queues a file
                    dict(sid=task.sid,
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

    def setup_watch_queue(self, sid):
        """
        Create a unique queue on which the service result keys of a submission
        are delivered as soon as they come in.

        If the submission is already being processed, every key received so
        far is replayed on the new queue so the requesting client starts out
        up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        # Create a unique queue
        queue_name = reply_queue_name(prefix="D", suffix="WQ")
        watch_queue = NamedQueue(queue_name, ttl=30)

        def send(**fields):
            # Wrap the fields in a WatchQueueMessage and put it on the new queue
            watch_queue.push(WatchQueueMessage(fields).as_primitives())

        send(status='START')

        # Register the new queue with the watchers of this submission
        self._get_watcher_list(sid).add(queue_name)

        # Look at the submission's status table in redis
        dispatch_hash = DispatchHash(sid, self.redis)
        table_is_empty = (dispatch_hash.dispatch_count() == 0
                          and dispatch_hash.finished_count() == 0)

        if not table_is_empty:
            # Replay every key already known (the queue only lives for a short time)
            for file_results in dispatch_hash.all_results().values():
                for entry in file_results.values():
                    send(status="FAIL" if entry.is_error else "OK",
                         cache_key=entry.key)
            return queue_name

        # This table is empty? do we have this submission at all?
        submission = self.ds.submission.get(sid)
        if not submission or submission.state == 'completed':
            send(status="STOP")
        else:
            # We do have a submission, remind the dispatcher to work on it
            self.submission_queue.push({'sid': sid})

        return queue_name

    def _get_watcher_list(self, sid):
        """Return the ExpiringSet named by ``make_watcher_list_name(sid)``."""
        set_name = make_watcher_list_name(sid)
        return ExpiringSet(set_name, host=self.redis)
Esempio n. 23
0
    def _request_work(self, worker_id, service_name, service_version, timeout,
                      blocking) -> Optional[ServiceTask]:
        """Make one attempt at claiming a task from the service's queue.

        :return: The claimed task, or None when the time budget is spent or
                 the queue produced nothing.
        :raises RetryRequestWork: when the dequeued task had to be discarded
                                  and the caller should try again.
        """
        # Nothing left of the time budget (reached when the caller retries)
        if int(timeout) <= 0:
            self.log.info(
                f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Pull a single message off this service's queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            message = work_queue.blocking_pop(timeout=int(timeout))
        else:
            popped = work_queue.pop(1)
            message = popped[0] if popped else popped

        if not message:
            self.log.info(
                f"{service_name}:{worker_id} no task returned: [empty message]"
            )
            return None
        task = ServiceTask(message)

        # If someone is supposed to be working on this task right now, the
        # add is refused and the caller has to fetch something else
        if not self.running_tasks.add(task.key(), task.as_primitives()):
            raise RetryRequestWork()

        sha256 = task.fileinfo.sha256
        self.log.info(
            f"[{task.sid}/{sha256}] {service_name}:{worker_id} task found")

        process_table = DispatchHash(task.sid, self.redis)

        abandoned = process_table.dispatch_time(
            file_hash=sha256, service=task.service_name) == 0
        finished = process_table.finished(
            file_hash=sha256, service=task.service_name) is not None

        # A service can be re-dispatched just as it finishes and end up marked
        # both finished and dispatched; drop the stale dispatch in that case
        if finished and not abandoned:
            process_table.drop_dispatch(file_hash=sha256,
                                        service=task.service_name)

        if abandoned or finished:
            self.log.info(
                f"[{task.sid}/{sha256}] {service_name}:{worker_id} task already complete"
            )
            self.running_tasks.pop(task.key())
            raise RetryRequestWork()

        # Enforce the retry limit for this task
        attempt_record = ExpiringHash(f'dispatch-hash-attempts-{task.sid}',
                                      host=self.redis)
        total_attempts = attempt_record.increment(task.key())
        self.log.info(f"[{task.sid}/{sha256}] {service_name}:{worker_id} "
                      f"task attempt {total_attempts}/3")
        if total_attempts > 3:
            self.log.warning(
                f"[{task.sid}/{sha256}] "
                f"{service_name}:{worker_id} marking task failed: TASK PREEMPTED "
            )
            error = Error({
                'archive_ts': now_as_iso(
                    self.config.datastore.ilm.days_until_archive * 24 * 60 *
                    60),
                'created': 'NOW',
                'expiry_ts': now_as_iso(task.ttl * 24 * 60 * 60)
                if task.ttl else None,
                'response': {
                    'message': 'The number of retries has passed the limit.',
                    'service_name': task.service_name,
                    'service_version': service_version,
                    'status': 'FAIL_NONRECOVERABLE',
                },
                'sha256': sha256,
                'type': "TASK PRE-EMPTED",
            })
            error_key = error.build_key(task=task)
            self.service_failed(task.sid, error_key, error)
            export_metrics_once(service_name,
                                Metrics,
                                dict(fail_nonrecoverable=1),
                                host=worker_id,
                                counter_type='service')
            raise RetryRequestWork()

        # Register the task with the timeout watcher using the service's timeout
        service_data = self.service_data[task.service_name]
        self.timeout_watcher.touch_task(timeout=int(service_data.timeout),
                                        key=f'{task.sid}-{task.key()}',
                                        worker=worker_id,
                                        task_key=task.key())
        return task
Esempio n. 24
0
    def __init__(self,
                 config=None,
                 datastore=None,
                 redis=None,
                 redis_persist=None):
        """Set up the scaler: redis structures, container controller and metrics.

        The controller is a KubernetesController when KUBERNETES_AL_CONFIG is
        set and a DockerController otherwise.

        :param config: Forwarded to the parent constructor.
        :param datastore: Forwarded to the parent constructor.
        :param redis: Forwarded to the parent constructor.
        :param redis_persist: Forwarded to the parent constructor.
        :raises KeyError: if an environment variable referenced as ``${NAME}``
                          in /etc/assemblyline/config.yml, or UI_SERVER, is
                          not set.
        :raises OSError: if /etc/assemblyline/config.yml cannot be read.
        """
        super().__init__('assemblyline.scaler',
                         config=config,
                         datastore=datastore,
                         redis=redis,
                         redis_persist=redis_persist)

        self.scaler_timeout_queue = NamedQueue(SCALER_TIMEOUT_QUEUE,
                                               host=self.redis_persist)
        self.error_count_lock = threading.Lock()
        self.error_count: dict[str, list[float]] = {}
        self.status_table = ExpiringHash(SERVICE_STATE_HASH,
                                         host=self.redis,
                                         ttl=30 * 60)
        self.service_event_sender = EventSender('changes.services',
                                                host=self.redis)
        self.service_change_watcher = EventWatcher(
            self.redis, deserializer=ServiceChange.deserialize)
        self.service_change_watcher.register('changes.services.*',
                                             self._handle_service_change_event)

        core_env: dict[str, str] = {}
        # If we have privileged services, we must be able to pass the necessary environment variables for them to
        # function properly.
        # Read the config inside a context manager so the handle is closed
        # right away instead of being left to the garbage collector.
        with open('/etc/assemblyline/config.yml', 'r') as config_file:
            config_text = config_file.read()
        for secret in re.findall(r'\${\w+}', config_text) + ['UI_SERVER']:
            # "${NAME}" -> "NAME"
            env_name = secret.strip("${}")
            core_env[env_name] = os.environ[env_name]

        labels = {
            'app': 'assemblyline',
            'section': 'service',
            'privilege': 'service'
        }

        # Extra labels are configured as "key=value" strings
        if self.config.core.scaler.additional_labels:
            labels.update({
                k: v
                for k, v in (
                    _l.split("=")
                    for _l in self.config.core.scaler.additional_labels)
            })

        if KUBERNETES_AL_CONFIG:
            self.log.info(
                f"Loading Kubernetes cluster interface on namespace: {NAMESPACE}"
            )
            self.controller = KubernetesController(
                logger=self.log,
                prefix='alsvc_',
                labels=labels,
                namespace=NAMESPACE,
                priority='al-service-priority',
                cpu_reservation=self.config.services.cpu_reservation,
                log_level=self.config.logging.log_level,
                core_env=core_env)
            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_CONFIGMAP:
                self.controller.config_mount(
                    'classification-config',
                    config_map=CLASSIFICATION_CONFIGMAP,
                    key=CLASSIFICATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/classification.yml')
            if CONFIGURATION_CONFIGMAP:
                self.controller.core_config_mount(
                    'assemblyline-config',
                    config_map=CONFIGURATION_CONFIGMAP,
                    key=CONFIGURATION_CONFIGMAP_KEY,
                    target_path='/etc/assemblyline/config.yml')
        else:
            self.log.info("Loading Docker cluster interface.")
            self.controller = DockerController(
                logger=self.log,
                prefix=NAMESPACE,
                labels=labels,
                log_level=self.config.logging.log_level,
                core_env=core_env)
            # NOTE(review): _service_stage_hash is not set in this method;
            # presumably the parent constructor provides it - confirm.
            self._service_stage_hash.delete()

            if DOCKER_CONFIGURATION_PATH and DOCKER_CONFIGURATION_VOLUME:
                self.controller.core_mounts.append(
                    (DOCKER_CONFIGURATION_VOLUME, '/etc/assemblyline/'))

                # Write the current configuration and classification definition
                # into the directory that is mounted into the containers
                with open(
                        os.path.join(DOCKER_CONFIGURATION_PATH, 'config.yml'),
                        'w') as handle:
                    yaml.dump(self.config.as_primitives(), handle)

                with open(
                        os.path.join(DOCKER_CONFIGURATION_PATH,
                                     'classification.yml'), 'w') as handle:
                    yaml.dump(get_classification().original_definition, handle)

            # If we know where to find it, mount the classification into the service containers
            if CLASSIFICATION_HOST_PATH:
                self.controller.global_mounts.append(
                    (CLASSIFICATION_HOST_PATH,
                     '/etc/assemblyline/classification.yml'))

        # Information about services
        self.profiles: dict[str, ServiceProfile] = {}
        self.profiles_lock = threading.RLock()

        # Prepare a single threaded scheduler
        self.state = collection.Collection(
            period=self.config.core.metrics.export_interval)
        self.stopping = threading.Event()
        self.main_loop_exit = threading.Event()

        # Load the APM connection if any
        self.apm_client = None
        if self.config.core.metrics.apm_server.server_url:
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="scaler")
class DispatchClient:
    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        """Wire the client to its configuration, redis connections and datastore.

        Every collaborator that is not supplied (or is falsy) is created from the
        configuration returned by ``forge.get_config()``.

        :param datastore: Datastore to use; otherwise ``forge.get_datastore`` is called.
        :param redis: Client for the non-persistent redis; otherwise one is opened.
        :param redis_persist: Client for the persistent redis; otherwise one is opened.
        :param logger: Logger to use; otherwise "assemblyline.dispatching.client".
        """
        self.config = forge.get_config()

        # Only open new connections when the caller did not hand one in
        if not redis:
            volatile = self.config.core.redis.nonpersistent
            redis = get_client(host=volatile.host, port=volatile.port, private=False)
        self.redis = redis

        if not redis_persist:
            persistent = self.config.core.redis.persistent
            redis_persist = get_client(host=persistent.host, port=persistent.port, private=False)
        self.redis_persist = redis_persist

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)

        if not datastore:
            datastore = forge.get_datastore(self.config)
        self.ds = datastore

        if not logger:
            logger = logging.getLogger("assemblyline.dispatching.client")
        self.log = logger

        # Shortcuts to the datastore collections
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file

        # Shared redis structures: sid -> dispatcher assignments and the running task table
        self.submission_assignments = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = Hash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))

        # State used by is_dispatcher(): last fetched dispatcher list, when it was
        # fetched, and the ids already found to be missing from that list
        self.dispatcher_data = []
        self.dispatcher_data_age = 0.0
        self.dead_dispatchers = []

    @weak_lru(maxsize=128)
    def _get_queue_from_cache(self, name):
        """Return a NamedQueue called *name* on ``self.redis`` with a ttl of QUEUE_EXPIRY.

        The ``weak_lru(maxsize=128)`` decorator wraps this call; presumably it memoizes
        the queue objects per name (confirm against its definition).
        """
        queue = NamedQueue(name, ttl=QUEUE_EXPIRY, host=self.redis)
        return queue

    def _get_services(self):
        """Return a dict mapping each service's name to the service, from the full datastore listing."""
        services = {}
        # noinspection PyUnresolvedReferences
        for service in self.ds.list_all_services(full=True):
            services[service.name] = service
        return services

    def is_dispatcher(self, dispatcher_id) -> bool:
        """Tell whether *dispatcher_id* is in the list of dispatcher instances.

        The instance list is re-fetched from ``Dispatcher.all_instances`` when it is
        more than 120 seconds old or does not contain the id. An id that is still
        absent after that is remembered in ``self.dead_dispatchers`` and rejected
        immediately on later calls.

        :param dispatcher_id: Identifier of the dispatcher to look up.
        :return: True if the id is in the (possibly refreshed) instance list.
        """
        # Ids already found missing are never looked up again
        if dispatcher_id in self.dead_dispatchers:
            return False

        cache_age = time.time() - self.dispatcher_data_age
        if cache_age > 120 or dispatcher_id not in self.dispatcher_data:
            self.dispatcher_data = Dispatcher.all_instances(self.redis_persist)
            self.dispatcher_data_age = time.time()

        known = dispatcher_id in self.dispatcher_data
        if not known:
            self.dead_dispatchers.append(dispatcher_id)
        return known

    def dispatch_bundle(self, submission: Submission, results: Dict[str, Result],
                        file_infos: Dict[str, File], file_tree, errors: Dict[str, Error], completed_queue: str = None):
        """Queue a bundle on the submission queue so scanning of its files can continue.

        Prerequisites:
            - Submission, results, file_infos and errors should already be saved in the datastore
            - Files should already be in the filestore

        The submission is converted with ``as_primitives()``; every other argument is
        placed in the queued message exactly as received.
        """
        message = {
            'submission': submission.as_primitives(),
            'results': results,
            'file_infos': file_infos,
            'file_tree': file_tree,
            'errors': errors,
            'completed_queue': completed_queue,
        }
        self.submission_queue.push(message)

    def dispatch_submission(self, submission: Submission, completed_queue: str = None):
        """Queue a submission on the submission queue.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore

        :param submission: Submission to dispatch; sent in its ``as_primitives()`` form.
        :param completed_queue: Value stored under 'completed_queue' in the queued message.
        """
        message = {
            'submission': submission.as_primitives(),
            'completed_queue': completed_queue,
        }
        self.submission_queue.push(message)

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        Ask the dispatcher assigned to a submission which services are still outstanding
        and how many files each of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and number of files
                 remaining per services e.g. {"SERVICE_NAME": 1, ... }.
                 An empty dict when no dispatcher is assigned to the submission;
                 otherwise whatever the reply queue's pop returns within 30 seconds.
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if not dispatcher_id:
            return {}

        # One-off queue on which the dispatcher is asked to answer
        reply_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
        reply_queue = NamedQueue(reply_name, host=self.redis, ttl=30)

        command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
        request = DispatcherCommandMessage({
            'kind': LIST_OUTSTANDING,
            'payload_data': ListOutstanding({
                'response_queue': reply_name,
                'submission': sid
            })
        })
        command_queue.push(request.as_primitives())

        return reply_queue.pop(timeout=30)

    @elasticapm.capture_span(span_type='dispatch_client')
    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True, low_priority=False) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        Keeps calling ``_request_work`` until a task is obtained or fewer than one
        whole second of the timeout remains.

        :param worker_id: Identifier of the worker asking for work.
        :param service_name: Which service needs work.
        :param service_version: Passed through to ``_request_work``.
        :param timeout: Overall time budget in seconds; nothing is attempted once
                        ``int()`` of the remaining time is 0.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return
                         the outcome of the first attempt immediately.
        :param low_priority: Passed through to ``_request_work``.
        :return: The task found, or None.
        """
        began = time.time()
        time_left = timeout
        while int(time_left) > 0:
            task = self._request_work(worker_id, service_name, service_version,
                                      blocking=blocking, timeout=time_left, low_priority=low_priority)
            # A non-blocking caller gets the first attempt's outcome, even when it is empty
            if task or not blocking:
                return task
            elapsed = time.time() - began
            time_left = timeout - elapsed
        return None

    def _request_work(self, worker_id, service_name, service_version,
                      timeout, blocking, low_priority=False) -> Optional[ServiceTask]:
        """Make a single attempt at taking one task off the service's queue.

        :param worker_id: Recorded in the task metadata under 'worker__'.
        :param service_name: Service whose queue is read.
        :param service_version: Not used by this method.
        :param timeout: Seconds to wait on a blocking pop (truncated to an int).
        :param blocking: Use a blocking pop when true, otherwise take at most one item now.
        :param low_priority: Passed to the blocking pop; in non-blocking mode selects
                             ``unpush`` instead of ``pop``.
        :return: The task, or None when the timeout is spent, the queue yielded nothing,
                 the task's dispatcher is not a known dispatcher, or the task could not
                 be added to the running task table.
        """
        # Callers retry with a shrinking timeout; stop once it is used up
        wait_seconds = int(timeout)
        if wait_seconds <= 0:
            self.log.info(f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Take one message off the service queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            message = work_queue.blocking_pop(timeout=wait_seconds, low_priority=low_priority)
        else:
            batch = work_queue.unpush(1) if low_priority else work_queue.pop(1)
            message = batch[0] if batch else batch

        if not message:
            self.log.info(f"{service_name}:{worker_id} no task returned: [empty message]")
            return None

        task = ServiceTask(message)
        task.metadata['worker__'] = worker_id
        dispatcher = task.metadata['dispatcher__']

        # A task issued by a dispatcher that is no longer listed is discarded
        if not self.is_dispatcher(dispatcher):
            self.log.info(f"{service_name}:{worker_id} no task returned: [task from dead dispatcher]")
            return None

        # Only hand the task out when it was successfully recorded as running
        if not self.running_tasks.add(task.key(), task.as_primitives()):
            return None

        self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found")
        start_events = self._get_queue_from_cache(DISPATCH_START_EVENTS + dispatcher)
        start_events.push((task.sid, task.fileinfo.sha256, service_name, worker_id))
        return task

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch.

        Steps: remove the task from the running task table (giving up with a warning if
        it is not there), save the result, tell every watcher of the submission, then
        push a summary of the result onto the owning dispatcher's result queue.

        :param sid: Submission ID the task belongs to.
        :param result_key: Datastore key under which the result is saved.
        :param result: The service result.
        :param temporary_data: Forwarded to the dispatcher under 'temporary_data'.
        """
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(sid=sid, service_name=result.response.service_name, sha=result.sha256)
        raw_task = self.running_tasks.pop(task_key)
        if not raw_task:
            self.log.warning(f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing successful results.")
            return
        task = ServiceTask(raw_task)

        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            # Save or freshen the result. The CONTENT shouldn't change, but the most distant
            # expiry time must win so the record isn't pulled out from under another
            # submission too early. Retried for as long as the save hits a version conflict.
            saved = False
            while not saved:
                old, version = self.ds.result.get_if_exists(
                    result_key, archive_access=self.config.datastore.ilm.update_archive, version=True)
                if old:
                    both_expire = old.expiry_ts and result.expiry_ts
                    result.expiry_ts = max(result.expiry_ts, old.expiry_ts) if both_expire else None
                try:
                    self.ds.result.save(result_key, result, version=version)
                    saved = True
                except VersionConflictException as vce:
                    self.log.info(f"Retrying to save results due to version conflict: {str(vce)}")

        # Send the result key to any watching systems
        notification = {'status': 'OK', 'cache_key': result_key}
        for watcher in self._get_watcher_list(task.sid).members():
            NamedQueue(watcher, host=self.redis).push(notification)

        # Flatten the tags of every section into a single list
        tags = [
            tag
            for section in result.result.sections
            for tag in tag_dict_to_list(flatten(section.tags.as_primitives()))
        ]

        # Pull out file names if we have them
        file_names = {
            extracted.sha256: extracted.name
            for extracted in result.response.extracted
            if extracted.name
        }

        # Report back to the dispatcher that issued the task
        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        # With no expiry set, the archive timestamp is reported as the expiry
        expiry_source = result.expiry_ts if result.expiry_ts else result.archive_ts
        ex_ts = expiry_source.strftime(DATEFORMAT)
        summary = {
            'key': result_key,
            'drop': result.drop_file,
            'score': result.result.score,
            'children': [r.sha256 for r in result.response.extracted],
        }
        result_queue.push({
            'sid': task.sid,
            'sha256': result.sha256,
            'service_name': task.service_name,
            'service_version': result.response.service_version,
            'service_tool_version': result.response.service_tool_version,
            'archive_ts': result.archive_ts.strftime(DATEFORMAT),
            'expiry_ts': ex_ts,
            'result_summary': summary,
            'tags': tags,
            'extracted_names': file_names,
            'temporary_data': temporary_data
        })

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_failed(self, sid: str, error_key: str, error: Error):
        """Notify the dispatcher that a service task ended in an error.

        The task is removed from the running task table; if it is not there a warning
        is logged and nothing else happens. Errors with status "FAIL_NONRECOVERABLE"
        are saved and announced to the submission's watchers. In every case the
        error is pushed onto the owning dispatcher's result queue.

        :param sid: Submission ID the task belongs to.
        :param error_key: Datastore key for the error.
        :param error: The error reported by the service.
        """
        task_key = ServiceTask.make_key(sid=sid, service_name=error.response.service_name, sha=error.sha256)
        raw_task = self.running_tasks.pop(task_key)
        if not raw_task:
            self.log.warning(f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(raw_task)

        self.log.debug(f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error.")

        nonrecoverable = error.response.status == "FAIL_NONRECOVERABLE"
        if nonrecoverable:
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            # Send the error key to any watching systems
            notification = {'status': 'FAIL', 'cache_key': error_key}
            for watcher in self._get_watcher_list(task.sid).members():
                NamedQueue(watcher, host=self.redis).push(notification)

        # Report back to the dispatcher that issued the task
        dispatcher = task.metadata['dispatcher__']
        report = {
            'sid': task.sid,
            'service_task': task.as_primitives(),
            'error': error.as_primitives(),
            'error_key': error_key
        }
        self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher).push(report)

    def setup_watch_queue(self, sid: str) -> Optional[str]:
        """
        Ask the dispatcher handling a submission to create a watch queue for it.

        A fresh queue name is generated and sent to the dispatcher in a CREATE_WATCH
        command; the service result keys for the submission are expected to be
        delivered on that queue by the dispatcher.

        :param sid: Submission ID
        :return: The name of the watch queue that was requested, or None when no
                 dispatcher is assigned to the submission
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if not dispatcher_id:
            return None

        watch_queue_name = reply_queue_name(prefix="D", suffix="WQ")
        command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, host=self.redis)
        request = DispatcherCommandMessage({
            'kind': CREATE_WATCH,
            'payload_data': CreateWatch({
                'queue_name': watch_queue_name,
                'submission': sid
            })
        })
        command_queue.push(request.as_primitives())
        return watch_queue_name

    def _get_watcher_list(self, sid):
        """Return the ExpiringSet on ``self.redis`` that holds the watcher list for *sid*."""
        list_name = make_watcher_list_name(sid)
        return ExpiringSet(list_name, host=self.redis)
Esempio n. 26
0
    def dispatch_file(self, task: FileTask):
        """ Handle a message describing a file to be processed.

        This file may be:
            - A new submission or extracted file.
            - A file that has just completed a stage of processing.
            - A file that has not completed a stage of processing, but this
              call has been triggered by a timeout or similar.

        The schedule for the file is obtained from build_schedule and walked stage by
        stage. The first stage that still has a service without a finished entry in
        the dispatch table is (re)dispatched in full; multiple dispatches of the same
        service are fine, but missing one is not.

        If a finished service asked for the file to be dropped (and the submission does
        not ignore filtering), the remaining stages are discarded, and the stages
        started so far are written back as the file's schedule so the discarded
        stages are not considered again next time.

        When nothing is outstanding the file is done: its tag set is deleted and, if
        the whole dispatch table reports finished, the submission is queued for
        finalization.

        :param task: The file task; supplies the sid, file info, depth and limits.
        """
        # Read the message content
        file_hash = task.file_info.sha256
        active_task = self.active_submissions.get(task.sid)

        if active_task is None:
            self.log.warning(f"[{task.sid}] Untracked submission is being processed")
            return

        submission_task = SubmissionTask(active_task)
        submission = submission_task.submission

        # Refresh the watch on the submission, we are still working on it
        self.timeout_watcher.touch(key=task.sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': task.sid})

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(task.sid, self.redis, fetch_results=True)

        # Load the tags and temporary data that are attached to every service task
        file_tags = ExpiringSet(task.get_tag_set_name(), host=self.redis)
        file_tags_data = file_tags.members()
        temporary_submission_data = ExpiringHash(task.get_temporary_submission_data_name(), host=self.redis)
        temporary_data = [dict(name=name, value=value)
                          for name, value in temporary_submission_data.items().items()]

        # Calculate the schedule for the file
        schedule = self.build_schedule(dispatch_table, submission, file_hash, task.file_info.type)
        started_stages = []

        # Go through each round of the schedule removing complete/failed services
        # Break when we find a stage that still needs processing
        outstanding = {}
        while schedule and not outstanding:
            stage = schedule.pop(0)
            started_stages.append(stage)

            for service_name in stage:
                service = self.scheduler.services.get(service_name)
                if not service:
                    continue

                # Load the results, if there are no results, then the service must be dispatched later
                # Don't look at if it has been dispatched, as multiple dispatches are fine,
                # but missing a dispatch isn't.
                finished = dispatch_table.finished(file_hash, service_name)
                if not finished:
                    outstanding[service_name] = service
                    continue

                # A service that terminated in an error cannot drop the file
                if finished.is_error:
                    continue

                # Check if the file has been dropped by the finished service
                if not submission.params.ignore_filtering and finished.drop:
                    # The check must come BEFORE the clear: if there are still stages in
                    # the schedule, overwrite them for next time with what was started.
                    if schedule:
                        dispatch_table.schedules.set(file_hash, started_stages)
                    schedule.clear()

        # Try to retry/dispatch any outstanding services
        if outstanding:
            self.log.info(f"[{task.sid}] File {file_hash} sent to services : {', '.join(list(outstanding.keys()))}")

            # Find the actual file name from the list of files in the submission; this is
            # the same for every service so it is resolved once, falling back to the hash
            filename = next((file.name for file in submission.files if file.sha256 == file_hash), None)
            filename = filename or task.file_info.sha256

            for service_name, service in outstanding.items():
                # Build the actual service dispatch message
                config = self.build_service_config(service, submission)
                service_task = ServiceTask(dict(
                    sid=task.sid,
                    metadata=submission.metadata,
                    min_classification=task.min_classification,
                    service_name=service_name,
                    service_config=config,
                    fileinfo=task.file_info,
                    filename=filename,
                    depth=task.depth,
                    max_files=task.max_files,
                    ttl=submission.params.ttl,
                    ignore_cache=submission.params.ignore_cache,
                    ignore_dynamic_recursion_prevention=submission.params.ignore_dynamic_recursion_prevention,
                    tags=file_tags_data,
                    temporary_submission_data=temporary_data,
                    deep_scan=submission.params.deep_scan,
                    priority=submission.params.priority,
                ))
                # Record the dispatch before the task becomes visible on the service queue
                dispatch_table.dispatch(file_hash, service_name)
                queue = get_service_queue(service_name, self.redis)
                queue.push(service_task.priority, service_task.as_primitives())

        else:
            # There are no outstanding services, this file is done
            # clean up the tags
            file_tags.delete()

            # If there are no outstanding ANYTHING for this submission,
            # send a message to the submission dispatcher to finalize
            self.counter.increment('files_completed')
            if dispatch_table.all_finished():
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}' starting submission finalization.")
                self.submission_queue.push({'sid': submission.sid})
            else:
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}'. Other files are not finished.")
Esempio n. 27
0
class Dispatcher:

    def __init__(self, datastore, redis, redis_persist, logger, counter_name='dispatcher'):
        """Set up the dispatcher's datastore handles, redis structures and metrics counter.

        :param datastore: Datastore providing the submission, result, error and file collections.
        :param redis: Client for the non-persistent redis; when falsy one is opened from the config.
        :param redis_persist: Client for the persistent redis; when falsy one is opened from the config.
        :param logger: Logger used by the dispatcher.
        :param counter_name: Name under which the metrics counter is created.
        """
        self.log: logging.Logger = logger

        # Keep handles on the datastore and the collections that are going to be used
        self.datastore: AssemblylineDatastore = datastore
        self.submissions: Collection = datastore.submission
        self.results: Collection = datastore.result
        self.errors: Collection = datastore.error
        self.files: Collection = datastore.file

        # System configuration as returned by forge
        self.config: Config = forge.get_config()

        # Use the redis connections handed in, opening new ones only when missing
        if not redis:
            volatile = self.config.core.redis.nonpersistent
            redis = get_client(host=volatile.host, port=volatile.port, private=False)
        self.redis = redis

        if not redis_persist:
            persistent = self.config.core.redis.persistent
            redis_persist = get_client(host=persistent.host, port=persistent.port, private=False)
        self.redis_persist = redis_persist

        # Build some utility classes
        self.scheduler = Scheduler(datastore, self.config, self.redis)
        self.classification_engine = forge.get_classification()
        self.timeout_watcher = WatcherClient(self.redis_persist)

        # Work queues on the non-persistent redis, plus a cache of other queues
        # filled in on demand by volatile_named_queue()
        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self._nonper_other_queues = {}

        # Tables of the submissions being processed and the service tasks being run
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)

        # Publish counters to the metrics sink.
        self.counter = MetricsFactory(metrics_type='dispatcher', schema=Metrics, name=counter_name,
                                      redis=self.redis, config=self.config)

    def volatile_named_queue(self, name: str) -> NamedQueue:
        """Return the NamedQueue for *name* on ``self.redis``, creating and caching it on first use."""
        try:
            return self._nonper_other_queues[name]
        except KeyError:
            queue = NamedQueue(name, self.redis)
            self._nonper_other_queues[name] = queue
            return queue

    def dispatch_submission(self, task: SubmissionTask):
        """
        Find any files associated with a submission and dispatch them if they are
        not marked as in progress. If all files are finished, finalize the submission.

        This version of dispatch submission doesn't verify each result, but assumes that
        the dispatch table has been kept up to date by other components.

        Preconditions:
            - File exists in the filestore and file collection in the datastore
            - Submission is stored in the datastore

        Outcomes, in the order they are checked:
            - If any file record is missing from the datastore, one FAIL_NONRECOVERABLE
              error per missing hash is saved and the submission is cancelled.
            - If any file still has work pending (and survives the depth and
              extraction-limit filters), each such file is pushed to the file queue.
            - Otherwise the submission is finalized with the highest per-file score.

        :param task: The submission task; task.submission supplies the sid, files and params.
        """
        submission = task.submission
        sid = submission.sid

        # Register the submission as active the first time it is seen
        if not self.active_submissions.exists(sid):
            self.log.info(f"[{sid}] New submission received")
            self.active_submissions.add(sid, task.as_primitives())
        else:
            self.log.info(f"[{sid}] Received a pre-existing submission, check if it is complete")

        # Refresh the watch, this ensures that this function will be called again
        # if something goes wrong with one of the files, and it never gets invoked by dispatch_file.
        self.timeout_watcher.touch(key=sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': sid})

        # Refresh the quota hold: record the sid with the current time in the submitter's hash
        if submission.params.quota_item and submission.params.submitter:
            self.log.info(f"[{sid}] Submission will count towards {submission.params.submitter.upper()} quota")
            Hash('submissions-' + submission.params.submitter, self.redis_persist).add(sid, isotime.now_as_iso())

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(submission.sid, self.redis, fetch_results=True)
        file_parents = dispatch_table.file_tree()  # Load the file tree data as well

        # All the submission files, and all the file_tree files, to be sure we don't miss any incomplete children
        unchecked_hashes = [submission_file.sha256 for submission_file in submission.files]
        unchecked_hashes = list(set(unchecked_hashes) | set(file_parents.keys()))

        # Using the file tree we can recalculate the depth of any file
        depth_limit = self.config.submission.max_extraction_depth
        file_depth = depths_from_tree(file_parents)

        # Try to find all files, and extracted files, and create task objects for them
        # (we will need the file data anyway for checking the schedule later)
        max_files = len(submission.files) + submission.params.max_extracted
        unchecked_files = []  # Files that haven't been checked yet
        try:
            for sha, file_data in self.files.multiget(unchecked_hashes).items():
                unchecked_files.append(FileTask(dict(
                    sid=sid,
                    min_classification=task.submission.classification,
                    file_info=dict(
                        magic=file_data.magic,
                        md5=file_data.md5,
                        mime=file_data.mime,
                        sha1=file_data.sha1,
                        sha256=file_data.sha256,
                        size=file_data.size,
                        type=file_data.type,
                    ),
                    # Files absent from the depth map are treated as depth 0
                    depth=file_depth.get(sha, 0),
                    max_files=max_files
                )))
        except MultiKeyError as missing:
            # At least one file record could not be loaded: save an error for each
            # missing hash and abandon the whole submission.
            errors = []
            for file_sha in missing.keys:
                error = Error(dict(
                    archive_ts=submission.archive_ts,
                    expiry_ts=submission.expiry_ts,
                    response=dict(
                        message="Submission couldn't be completed due to missing file.",
                        service_name="dispatcher",
                        service_tool_version='4',
                        service_version='4',
                        status="FAIL_NONRECOVERABLE",
                    ),
                    sha256=file_sha,
                    type='UNKNOWN'
                ))
                # The sid is passed as the tool version when building the error key
                error_key = error.build_key(service_tool_version=sid)
                self.datastore.error.save(error_key, error)
                errors.append(error_key)
            return self.cancel_submission(task, errors, file_parents)

        # Files for which at least one scheduled service is still running or has no result yet
        pending_files = {}

        # Track information about the results as we hit them
        file_scores: Dict[str, int] = {}

        # Load the current state of the dispatch table in one go rather than one at a time in the loop
        prior_dispatches = dispatch_table.all_dispatches()

        # Walk the schedule of every file, collecting scores and noting files with unfinished work
        for file_task in unchecked_files:
            sha = file_task.file_info.sha256
            schedule = self.build_schedule(dispatch_table, submission, sha, file_task.file_info.type)

            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    # Only active services should be in this dict, so if a service that was placed in the
                    # schedule is now missing it has been disabled or taken offline.
                    service = self.scheduler.services.get(service_name)
                    if not service:
                        continue

                    # If the service is still marked as 'in progress'
                    # (a service with no recorded dispatch gets a dispatch time of 0, so its
                    # runtime is the current epoch time)
                    runtime = time.time() - prior_dispatches.get(sha, {}).get(service_name, 0)
                    if runtime < service.timeout:
                        pending_files[sha] = file_task
                        continue

                    # It hasn't started, has timed out, or is finished, see if we have a result
                    result_row = dispatch_table.finished(sha, service_name)

                    # No result found, mark the file as incomplete
                    if not result_row:
                        pending_files[sha] = file_task
                        continue

                    # A dropped file skips all later stages; the rest of this stage is still examined
                    if not submission.params.ignore_filtering and result_row.drop:
                        schedule.clear()

                    # The process table is marked that a service has been abandoned due to errors
                    if result_row.is_error:
                        continue

                    # Collect information about the result
                    file_scores[sha] = file_scores.get(sha, 0) + result_row.score

        # Using the file tree find the most shallow parent of the given file
        def lowest_parent(_sha):
            # A root file won't have any parents in the dict
            if _sha not in file_parents or None in file_parents[_sha]:
                return None
            # Parents with no known depth are ranked at the depth limit; ties on depth
            # are broken by comparing the parent values themselves (tuple ordering)
            return min((file_depth.get(parent, depth_limit), parent) for parent in file_parents[_sha])[1]

        # Filter out things over the depth limit
        pending_files = {sha: ft for sha, ft in pending_files.items() if ft.depth < depth_limit}

        # Filter out files based on the extraction limits
        # NOTE(review): add_file is called once per remaining file as part of this filter and
        # presumably registers the file against max_files in the dispatch table - confirm
        pending_files = {sha: ft for sha, ft in pending_files.items()
                         if dispatch_table.add_file(sha, max_files, lowest_parent(sha))}

        # If there are pending files, then at least one service, on at least one
        # file isn't done yet, and hasn't been filtered by any of the previous few steps
        # poke those files
        if pending_files:
            self.log.debug(f"[{sid}] Dispatching {len(pending_files)} files: {list(pending_files.keys())}")
            for file_task in pending_files.values():
                self.file_queue.push(file_task.as_primitives())
        else:
            self.log.debug(f"[{sid}] Finalizing submission.")
            max_score = max(file_scores.values()) if file_scores else 0  # Submissions with no results have no score
            # Only files that produced at least one scored result are passed on as the file list
            self.finalize_submission(task, max_score, file_scores.keys())

    def _cleanup_submission(self, task: SubmissionTask, file_list: List[str]):
        """Clean up code that is the same for canceled and finished submissions.

        In order: delete the temporary data hash of every listed file, release the
        submitter's quota hold, push the submission to the task's completed queue (if
        one was requested), send a STOP message to every watcher, then drop the watcher
        list, the timeout watch and the active submission entry.

        :param task: The submission task being wrapped up.
        :param file_list: Iterable of file hashes whose temporary data should be erased.
        """
        submission = task.submission
        sid = submission.sid

        # Erase the temporary data which may have accumulated during processing
        for file_hash in file_list:
            hash_name = get_temporary_submission_data_name(sid, file_hash=file_hash)
            ExpiringHash(hash_name, host=self.redis).delete()

        if submission.params.quota_item and submission.params.submitter:
            self.log.info(f"[{sid}] Submission no longer counts toward {submission.params.submitter.upper()} quota")
            Hash('submissions-' + submission.params.submitter, self.redis_persist).pop(sid)

        if task.completed_queue:
            self.volatile_named_queue(task.completed_queue).push(submission.as_primitives())

        # Send complete message to any watchers.
        # The watch queues are addressed through self.redis, the same connection the
        # watcher list lives on, rather than through a default connection.
        watcher_list = ExpiringSet(make_watcher_list_name(sid), host=self.redis)
        for w in watcher_list.members():
            NamedQueue(w, host=self.redis).push(WatchQueueMessage({'status': 'STOP'}).as_primitives())

        # Clear the timeout watcher
        watcher_list.delete()
        self.timeout_watcher.clear(sid)
        self.active_submissions.pop(sid)

        # Count the submission as 'complete' either way
        self.counter.increment('submissions_completed')

    def cancel_submission(self, task: SubmissionTask, errors, file_list):
        """The submission is being abandoned, delete everything, write failed state.

        :param task: The submission task being cancelled.
        :param errors: Error keys to record on the submission.
        :param file_list: File hashes handed to the shared clean up step.
        """
        submission = task.submission
        sid = submission.sid

        # The dispatch table is of no further use, clear it from redis
        DispatchHash(sid, self.redis).delete()

        # Record the failure on the submission and persist it
        submission.classification = submission.params.classification
        submission.errors = errors
        submission.error_count = len(errors)
        submission.state = 'failed'
        submission.times.completed = isotime.now_as_iso()
        self.submissions.save(sid, submission)

        self._cleanup_submission(task, file_list)
        self.log.error(f"[{sid}] Failed")

    def finalize_submission(self, task: SubmissionTask, max_score, file_list):
        """Close out a submission whose services have all finished or failed.

        Reads the outcome of every service from the dispatch table, removes the
        table from redis, saves the completed submission record to the datastore
        and runs the shared submission cleanup.

        :param task: the submission task being finalized
        :param max_score: score to store as the submission's max_score
        :param file_list: hashes of the files in the submission (only its length
                          and members are used)
        """
        submission = task.submission
        sid = submission.sid

        # Grab everything we need from the dispatch table before removing it from redis
        table = DispatchHash(sid, self.redis)
        rows = table.all_results()
        errors = table.all_extra_errors()
        table.delete()

        # Split the service outputs into error keys and result keys
        results = []
        for status in (entry for row in rows.values() for entry in row.values()):
            if status.is_error:
                errors.append(status.key)
                continue
            if status.bucket != 'result':
                self.log.warning(f"[{sid}] Unexpected service output bucket: {status.bucket}/{status.key}")
                continue
            results.append(status.key)

        # Fill in the final state of the submission and persist it
        submission.classification = submission.params.classification
        submission.errors = errors
        submission.error_count = len(errors)
        submission.results = results
        submission.file_count = len(file_list)
        submission.max_score = max_score
        submission.state = 'completed'
        submission.times.completed = isotime.now_as_iso()
        self.submissions.save(sid, submission)

        self._cleanup_submission(task, file_list)
        self.log.info(f"[{sid}] Completed; files: {len(file_list)} results: {len(results)} "
                      f"errors: {len(errors)} score: {max_score}")

    def dispatch_file(self, task: FileTask):
        """Handle a message describing a file to be processed.

        This file may be:
            - A new submission or extracted file.
            - A file that has just completed a stage of processing.
            - A file that has not completed a stage of processing, but this
              call has been triggered by a timeout or similar.

        If the file is totally new, we will setup a dispatch table, and fill it in.

        Once we make/load a dispatch table, we will dispatch whichever group the table
        shows us hasn't been completed yet.

        When we dispatch to a service, we check if the task is already in the dispatch
        queue. If it isn't proceed normally. If it is, check that the service is still online.

        :param task: the file task to process; nothing is done (beyond a warning)
                     when its submission is not in the active submissions table
        """
        # Read the message content
        file_hash = task.file_info.sha256
        active_task = self.active_submissions.get(task.sid)

        if active_task is None:
            self.log.warning(f"[{task.sid}] Untracked submission is being processed")
            return

        submission_task = SubmissionTask(active_task)
        submission = submission_task.submission

        # Refresh the watch on the submission, we are still working on it
        self.timeout_watcher.touch(key=task.sid, timeout=int(self.config.core.dispatcher.timeout),
                                   queue=SUBMISSION_QUEUE, message={'sid': task.sid})

        # Open up the file/service table for this submission
        dispatch_table = DispatchHash(task.sid, self.redis, fetch_results=True)

        # Load things that we will need to fill out the service tasks
        file_tags = ExpiringSet(task.get_tag_set_name(), host=self.redis)
        file_tags_data = file_tags.members()
        temporary_submission_data = ExpiringHash(task.get_temporary_submission_data_name(), host=self.redis)
        temporary_data = [dict(name=name, value=value)
                          for name, value in temporary_submission_data.items().items()]

        # Calculate the schedule for the file
        schedule = self.build_schedule(dispatch_table, submission, file_hash, task.file_info.type)
        started_stages = []

        # Go through each round of the schedule removing complete/failed services
        # Break when we find a stage that still needs processing
        outstanding = {}
        while schedule and not outstanding:
            stage = schedule.pop(0)
            started_stages.append(stage)

            for service_name in stage:
                service = self.scheduler.services.get(service_name)
                if not service:
                    continue

                # Load the results, if there are no results, then the service must be dispatched later
                # Don't look at if it has been dispatched, as multiple dispatches are fine,
                # but missing a dispatch isn't.
                finished = dispatch_table.finished(file_hash, service_name)
                if not finished:
                    outstanding[service_name] = service
                    continue

                # A service that terminated in an error has nothing more to contribute here
                if finished.is_error:
                    continue

                # If the service finished, check if the file has been dropped
                if not submission.params.ignore_filtering and finished.drop:
                    # The check for remaining stages must happen BEFORE the schedule is
                    # cleared, otherwise it can never be true and the cached schedule
                    # would never be overwritten.
                    if schedule:  # If there are still stages in the schedule, over write them for next time
                        dispatch_table.schedules.set(file_hash, started_stages)
                    schedule.clear()

        # Try to retry/dispatch any outstanding services
        if outstanding:
            self.log.info(f"[{task.sid}] File {file_hash} sent to services : {', '.join(list(outstanding.keys()))}")

            # Find the actual file name from the list of files in submission.
            # It is the same for every service, so look it up only once.
            filename = next((file.name for file in submission.files if file.sha256 == file_hash), None)

            for service_name, service in outstanding.items():
                # Build the actual service dispatch message
                config = self.build_service_config(service, submission)
                service_task = ServiceTask(dict(
                    sid=task.sid,
                    metadata=submission.metadata,
                    min_classification=task.min_classification,
                    service_name=service_name,
                    service_config=config,
                    fileinfo=task.file_info,
                    filename=filename or task.file_info.sha256,
                    depth=task.depth,
                    max_files=task.max_files,
                    ttl=submission.params.ttl,
                    ignore_cache=submission.params.ignore_cache,
                    ignore_dynamic_recursion_prevention=submission.params.ignore_dynamic_recursion_prevention,
                    tags=file_tags_data,
                    temporary_submission_data=temporary_data,
                    deep_scan=submission.params.deep_scan,
                    priority=submission.params.priority,
                ))
                dispatch_table.dispatch(file_hash, service_name)
                queue = get_service_queue(service_name, self.redis)
                queue.push(service_task.priority, service_task.as_primitives())

        else:
            # There are no outstanding services, this file is done
            # clean up the tags
            file_tags.delete()

            # If there are no outstanding ANYTHING for this submission,
            # send a message to the submission dispatcher to finalize
            self.counter.increment('files_completed')
            if dispatch_table.all_finished():
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}' starting submission finalization.")
                self.submission_queue.push({'sid': submission.sid})
            else:
                self.log.info(f"[{task.sid}] Finished processing file '{file_hash}'. Other files are not finished.")

    def build_schedule(self, dispatch_hash: DispatchHash, submission: Submission,
                       file_hash: str, file_type: str) -> List[List[str]]:
        """Return the service schedule for a file, building and caching it on first use.

        The schedule is looked up in the dispatch hash by file hash. When nothing
        (or an empty schedule) is stored there, the scheduler builds one from the
        submission and the file type, and the resulting list of stages (each a
        list of service names) is stored in the dispatch hash.
        """
        schedule = dispatch_hash.schedules.get(file_hash)
        if schedule:
            return schedule

        # Nothing cached yet: ask the scheduler, based on the submission parameters and file type
        stages = self.scheduler.build_schedule(submission, file_type)

        # The scheduler hands back service objects keyed by name, only the names are kept
        schedule = []
        for stage in stages:
            schedule.append(list(stage.keys()))

        dispatch_hash.schedules.add(file_hash, schedule)
        return schedule

    @classmethod
    def build_service_config(cls, service: Service, submission: Submission) -> Dict[str, str]:
        """Assemble the parameters a service will be run with for this submission.

        Starts from the default of every submission parameter declared by the
        service, then applies the values given for that service in the
        submission's service_spec, when there are any.

        v3 names: get_service_params get_config_data
        """
        # Defaults declared by the service itself
        config = {}
        for param in service.submission_params:
            config[param.name] = param.default

        # Values from the submission take precedence over the defaults
        if service.name not in submission.params.service_spec:
            return config
        config.update(submission.params.service_spec[service.name])
        return config