class Counters(object):
    def __init__(self, prefix="counter", host=None, port=None, track_counters=False):
        self.c = get_client(host, port, False)
        self.prefix = prefix
        if track_counters:
            self.tracker = Hash("c-tracker-%s" % prefix, host=host, port=port)
        else:
            self.tracker = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.delete()

    def inc(self, name, value=1, track_id=None):
        if self.tracker:
            self.tracker.add(track_id or name, now_as_iso())
        return retry_call(self.c.incr, "%s-%s" % (self.prefix, name), value)

    def dec(self, name, value=1, track_id=None):
        if self.tracker:
            self.tracker.pop(str(track_id or name))
        return retry_call(self.c.decr, "%s-%s" % (self.prefix, name), value)

    def get_queues_sizes(self):
        out = {}
        for queue in retry_call(self.c.keys, "%s-*" % self.prefix):
            queue_size = int(retry_call(self.c.get, queue))
            out[queue] = queue_size

        return {k.decode('utf-8'): v for k, v in out.items()}

    def get_queues(self):
        return [k.decode('utf-8') for k in retry_call(self.c.keys, "%s-*" % self.prefix)]

    def ready(self):
        try:
            self.c.ping()
        except ConnectionError:
            return False

        return True

    def reset_queues(self):
        if self.tracker:
            self.tracker.delete()
        for queue in retry_call(self.c.keys, "%s-*" % self.prefix):
            retry_call(self.c.set, queue, "0")

    def delete(self):
        if self.tracker:
            self.tracker.delete()
        for queue in retry_call(self.c.keys, "%s-*" % self.prefix):
            retry_call(self.c.delete, queue)
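
A minimal usage sketch for the Counters class above (hedged: the prefix, host and port are placeholders, and a reachable Redis instance behind get_client is assumed):

# Illustrative only; counter names are arbitrary.
with Counters(prefix="ingest", host="localhost", port=6379, track_counters=True) as counters:
    counters.inc("submissions")           # increments the Redis key "ingest-submissions"
    counters.inc("bytes", value=1024)
    print(counters.get_queues_sizes())    # e.g. {'ingest-submissions': 1, 'ingest-bytes': 1024}
    counters.dec("submissions")
# Leaving the with-block calls delete(), removing every "ingest-*" key and the tracker hash.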
Example 2
    def do_ui(self, args):
        """
        Perform UI related operations

        Usage:
            ui show_sessions [username]
            ui clear_sessions [username]

        actions:
            show_sessions      show all active sessions
            clear_sessions     Removes all active sessions

        Parameters:
            username           User used to filter sessions
                               [optional]

        Examples:
            # Clear sessions for user bob
            ui clear_sessions bob

            # Show all current sessions
            ui show_sessions
        """
        valid_func = ['clear_sessions', 'show_sessions']
        args = self._parse_args(args)

        if len(args) not in [1, 2]:
            self._print_error("Wrong number of arguments for ui command.")
            return

        func = args[0]
        if func not in valid_func:
            self._print_error(f"Invalid action '{func}' for ui command.")
            return

        if func == 'clear_sessions':
            username = None
            if len(args) == 2:
                username = args[1]

            flsk_sess = Hash("flask_sessions",
                             host=config.core.redis.nonpersistent.host,
                             port=config.core.redis.nonpersistent.port)

            if not username:
                flsk_sess.delete()
                self.logger.info("All sessions were cleared.")
            else:
                for k, v in flsk_sess.items().items():
                    if v.get('username', None) == username:
                        self.logger.info(f"Removing session: {v}")
                        flsk_sess.pop(k)

                self.logger.info(
                    f"All sessions for user '{username}' removed.")
        elif func == 'show_sessions':
            username = None
            if len(args) == 2:
                username = args[1]

            flsk_sess = Hash("flask_sessions",
                             host=config.core.redis.nonpersistent.host,
                             port=config.core.redis.nonpersistent.port)

            if not username:
                for k, v in flsk_sess.items().items():
                    self.logger.info(f"{v.get('username', None)} => {v}")
            else:
                self.logger.info(f'Showing sessions for user {username}:')
                for k, v in flsk_sess.items().items():
                    if v.get('username', None) == username:
                        self.logger.info(f"    {v}")
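
For reference, a sketch of what the session loops above iterate over (hedged: the flask session payload is stored in the "flask_sessions" Hash keyed by session id; fields other than 'username' depend on the UI and are illustrative):

# Illustrative only.
flsk_sess = Hash("flask_sessions", host="127.0.0.1", port=6379)
for session_id, session in flsk_sess.items().items():
    print(session_id, session.get('username', None))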
Example 3
class Ingester:
    """Internal interface to the ingestion queues."""
    def __init__(self,
                 datastore,
                 logger,
                 classification=None,
                 redis=None,
                 persistent_redis=None,
                 metrics_name='ingester'):
        self.datastore = datastore
        self.log = logger

        # Cache the user groups
        self.cache_lock = threading.RLock()  # TODO are middle man instances single threaded now?
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Create a config cache that will refresh config values periodically
        self.config = forge.CachedObject(forge.get_config)

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded based on a non-constant path, so this has to be done at init rather than at load time
        constants = forge.get_constants(self.config)
        self.priority_value = constants.PRIORITIES
        self.priority_range = constants.PRIORITY_RANGES
        self.threshold_value = constants.PRIORITY_THRESHOLDS

        # Connect to the redis servers
        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )
        self.persistent_redis = persistent_redis or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester',
                                      schema=Metrics,
                                      redis=self.redis,
                                      config=self.config,
                                      name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.persistent_redis)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(_completeq_name, self.redis)

        # Internal. Dropped entries are placed on this queue.
        # self.drop_queue = NamedQueue('m-drop', self.persistent_redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME,
                                       self.persistent_redis)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.persistent_redis)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.persistent_redis)

        # Internal, timeout watch queue
        self.timeout_queue = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally, but instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, in the complete
        #   method, not only is the original ingestion finalized, but all entries in the duplicate queue
        #   are finalized as well. This has the effect that all concurrent ingestion of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.persistent_redis)

        # Output. submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.persistent_redis)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore,
                                              redis=self.redis)

    def get_groups_from_user(self, username: str) -> List[str]:
        # Reset the group cache at the top of each hour
        if time.time() // HOUR_IN_SECONDS > self._user_groups_reset:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        # Get the groups for this user if not known
        if username not in self._user_groups:
            user_data = self.datastore.user.get(username)
            if user_data:
                self._user_groups[username] = user_data.groups
            else:
                self._user_groups[username] = []
        return self._user_groups[username]

    def ingest(self, task: IngestTask):
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Task received for processing"
        )
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task,
                                   failure="Invalid sha256",
                                   logfunc=self.log.warning)
            return

        # Clean up metadata strings, since we may delete some, iterate on a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] '
                    f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(
                task.params.submitter)

        # Check if this file is already being processed
        pprevious, previous, score = None, False, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        current_time = now()
        if priority and self.expired(
                current_time - task.submission.time.timestamp(), 0):
            priority = (priority / 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(self, task: IngestTask):
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                self.counter.increment('cache_miss')
                return None, False, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired"
            )
            self.counter.increment('cache_expired')
            self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, False, None, key
        elif self.stale(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale"
            )
            self.counter.increment('cache_stale')
            return None, False, result.score, key

        return result.psid, result.sid, result.score, key

    def stale(self, delta: float, errors: int):
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256=None):
        if not sha256:
            sha256 = task.submission.files[0].sha256

        key = task.scan_key

        if not key:
            key = task.params.create_filescore_key(sha256)
            task.scan_key = key

        return key

    def completed(self, sub):
        """Invoked when notified that a submission has completed."""
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
            self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                               blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = [dup for dup in exhaust()]
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(
                note_queue, self.persistent_redis)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng),
                                        sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256},
                                            expiry,
                                            c12n,
                                            redis=self.redis)

    def is_whitelisted(self, task: IngestTask):
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=_completeq_name,
        )

        self.timeout_queue.push(int(now(_max_time)), task.scan_key)
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis"
        )

    def retry(self, task, scan_key, ex):
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(
                f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})'
            )
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.json())

    def finalize(self, psid, sid, score, task: IngestTask):
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                f"to {'update' if will_resubmit else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis"
            )
            task.params.psid = sid
            task.submission.sid = None
            task.params.services.resubmit = []
            task.scan_key = None
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score):
        if not task.params.generate_alert:
            return False

        if score < self.threshold_value['critical']:
            return False

        return True
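
A rough sketch of how this Ingester might be driven (hedged: the real service wires ingest_queue consumption, retries and metrics into its own run loop, and the pop call shown here is simplified):

# Illustrative driver loop, not the shipped entrypoint.
ingester = Ingester(datastore=forge.get_datastore(), logger=logging.getLogger("ingester"))
while True:
    raw = ingester.ingest_queue.pop()     # raw submission request pushed by an external process
    if raw is None:
        continue
    ingester.ingest(IngestTask(raw))      # dedup/cache check, priority assignment, drop/whitelist checks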
Example 4
class ServiceUpdater(CoreBase):
    def __init__(self,
                 redis_persist=None,
                 redis=None,
                 logger=None,
                 datastore=None):
        super().__init__('assemblyline.service.updater',
                         logger=logger,
                         datastore=datastore,
                         redis_persist=redis_persist,
                         redis=redis)

        if not FILE_UPDATE_DIRECTORY:
            raise RuntimeError(
                "The updater process must be run within the orchestration environment, "
                "the update volume must be mounted, and the path to the volume must be "
                "set in the environment variable FILE_UPDATE_DIRECTORY. Setting "
                "FILE_UPDATE_DIRECTORY directly may be done for testing.")

        # The directory where we want working temporary directories to be created.
        # Building our temporary directories in the persistent update volume may
        # have some performance down sides, but may help us run into fewer docker FS overlay
        # cleanup issues. Try to flush it out every time we start. This service should
        # be a singleton anyway.
        self.temporary_directory = os.path.join(FILE_UPDATE_DIRECTORY, '.tmp')
        shutil.rmtree(self.temporary_directory, ignore_errors=True)
        os.makedirs(self.temporary_directory)

        self.container_update = Hash('container-update', self.redis_persist)
        self.services = Hash('service-updates', self.redis_persist)
        self.latest_service_tags = Hash('service-tags', self.redis_persist)
        self.running_updates: Dict[str, Thread] = {}

        # Prepare a single threaded scheduler
        self.scheduler = sched.scheduler()

        # Pick the orchestration backend: Kubernetes when running in a cluster, Docker otherwise
        if 'KUBERNETES_SERVICE_HOST' in os.environ and NAMESPACE:
            self.controller = KubernetesUpdateInterface(
                prefix='alsvc_',
                namespace=NAMESPACE,
                priority_class='al-core-priority')
        else:
            self.controller = DockerUpdateInterface()

    def sync_services(self):
        """Download the service list and make sure our settings are up to date"""
        self.scheduler.enter(SERVICE_SYNC_INTERVAL, 0, self.sync_services)
        existing_services = (set(self.services.keys())
                             | set(self.container_update.keys())
                             | set(self.latest_service_tags.keys()))
        discovered_services = []

        # Get all the service data
        for service in self.datastore.list_all_services(full=True):
            discovered_services.append(service.name)

            # Ensure that any disabled services are not being updated
            if not service.enabled and self.services.exists(service.name):
                self.log.info(f"Service updates disabled for {service.name}")
                self.services.pop(service.name)

            if not service.enabled:
                continue

            # Ensure that any enabled services with an update config are being updated
            stage = self.get_service_stage(service.name)
            record = self.services.get(service.name)

            if stage in UPDATE_STAGES and service.update_config:
                # Stringify and hash the current update configuration
                config_hash = hash(
                    json.dumps(service.update_config.as_primitives()))

                # If we can update, but there is no record, create one
                if not record:
                    self.log.info(
                        f"Service updates enabled for {service.name}")
                    self.services.add(
                        service.name,
                        dict(
                            next_update=now_as_iso(),
                            previous_update=now_as_iso(-10**10),
                            config_hash=config_hash,
                            sha256=None,
                        ))
                else:
                    # If there is a record, check that its configuration hash is still good
                    # If an update is in progress, it may overwrite this, but we will just come back
                    # and reapply this again in the iteration after that
                    if record.get('config_hash', None) != config_hash:
                        record['next_update'] = now_as_iso()
                        record['config_hash'] = config_hash
                        self.services.set(service.name, record)

            if stage == ServiceStage.Update:
                if ((record and record.get('sha256', None) is not None)
                        or not service.update_config):
                    self._service_stage_hash.set(service.name,
                                                 ServiceStage.Running)

        # Remove services we have locally or in redis that have been deleted from the database
        for stray_service in existing_services - set(discovered_services):
            self.log.info(f"Service updates disabled for {stray_service}")
            self.services.pop(stray_service)
            self._service_stage_hash.pop(stray_service)
            self.container_update.pop(stray_service)
            self.latest_service_tags.pop(stray_service)

    def container_updates(self):
        """Go through the list of services and check what are the latest tags for it"""
        self.scheduler.enter(UPDATE_CHECK_INTERVAL, 0, self.container_updates)
        for service_name, update_data in self.container_update.items().items():
            self.log.info(
                f"Service {service_name} is being updated to version {update_data['latest_tag']}..."
            )

            # Load authentication params
            username = None
            password = None
            auth = update_data['auth'] or {}
            if auth:
                username = auth.get('username', None)
                password = auth.get('password', None)

            try:
                self.controller.launch(
                    name=service_name,
                    docker_config=DockerConfig(
                        dict(allow_internet_access=True,
                             registry_username=username,
                             registry_password=password,
                             cpu_cores=1,
                             environment=[],
                             image=update_data['image'],
                             ports=[])),
                    mounts=[],
                    env={
                        "SERVICE_TAG": update_data['latest_tag'],
                        "SERVICE_API_HOST": os.environ.get('SERVICE_API_HOST',
                                                           "http://al_service_server:5003"),
                        "REGISTER_ONLY": 'true'
                    },
                    network='al_registration',
                    blocking=True)

                latest_tag = update_data['latest_tag'].replace('stable', '')

                service_key = f"{service_name}_{latest_tag}"

                if self.datastore.service.get_if_exists(service_key):
                    operations = [(self.datastore.service_delta.UPDATE_SET,
                                   'version', latest_tag)]
                    if self.datastore.service_delta.update(
                            service_name, operations):
                        # Update completed, cleanup
                        self.log.info(
                            f"Service {service_name} update successful!")
                    else:
                        self.log.error(
                            f"Service {service_name} has failed to update because it cannot set "
                            f"{latest_tag} as the new version. Update procedure cancelled..."
                        )
                else:
                    self.log.error(
                        f"Service {service_name} has failed to update because resulting "
                        f"service key ({service_key}) does not exist. Update procedure cancelled..."
                    )
            except Exception as e:
                self.log.error(
                    f"Service {service_name} has failed to update. Update procedure cancelled... [{str(e)}]"
                )

            self.container_update.pop(service_name)

    def container_versions(self):
        """Go through the list of services and check what are the latest tags for it"""
        self.scheduler.enter(CONTAINER_CHECK_INTERVAL, 0,
                             self.container_versions)

        for service in self.datastore.list_all_services(full=True):
            if not service.enabled:
                continue

            image_name, tag_name, auth = get_latest_tag_for_service(
                service, self.config, self.log)

            self.latest_service_tags.set(
                service.name, {
                    'auth': auth,
                    'image': image_name,
                    service.update_channel: tag_name
                })

    def try_run(self):
        """Run the scheduler loop until told to stop."""
        # Do an initial call to the main methods, who will then be registered with the scheduler
        self.sync_services()
        self.update_services()
        self.container_versions()
        self.container_updates()
        self.heartbeat()

        # Run as long as we need to
        while self.running:
            delay = self.scheduler.run(False)
            time.sleep(min(delay, 0.1))

    def heartbeat(self):
        """Periodically touch a file on disk.

        Since tasks are run serially, the delay between touches will be the maximum of
        HEARTBEAT_INTERVAL and the longest running task.
        """
        if self.config.logging.heartbeat_file:
            self.scheduler.enter(HEARTBEAT_INTERVAL, 0, self.heartbeat)
            super().heartbeat()

    def update_services(self):
        """Check if we need to update any services.

        Spin off a thread to actually perform any updates. Don't allow multiple threads per service.
        """
        self.scheduler.enter(UPDATE_CHECK_INTERVAL, 0, self.update_services)

        # Check for finished update threads
        self.running_updates = {
            name: thread
            for name, thread in self.running_updates.items()
            if thread.is_alive()
        }

        # Check if it's time to try to update the service
        for service_name, data in self.services.items().items():
            if data['next_update'] <= now_as_iso() and service_name not in self.running_updates:
                self.log.info(f"Time to update {service_name}")
                self.running_updates[service_name] = Thread(
                    target=self.run_update,
                    kwargs=dict(service_name=service_name))
                self.running_updates[service_name].start()

    def run_update(self, service_name):
        """Common setup and tear down for all update types."""
        # noinspection PyBroadException
        try:
            # Check for new update with service specified update method
            service = self.datastore.get_service_with_delta(service_name)
            update_method = service.update_config.method
            update_data = self.services.get(service_name)
            update_hash = None

            try:
                # Actually run the update method
                if update_method == 'run':
                    update_hash = self.do_file_update(
                        service=service,
                        previous_hash=update_data['sha256'],
                        previous_update=update_data['previous_update'])
                elif update_method == 'build':
                    update_hash = self.do_build_update()

                # If we have performed an update, write that data
                if update_hash is not None and update_hash != update_data['sha256']:
                    update_data['sha256'] = update_hash
                    update_data['previous_update'] = now_as_iso()
                else:
                    update_hash = None

            finally:
                # Update the next service update check time, don't update the config_hash,
                # as we don't want to disrupt being re-run if our config has changed during this run
                update_data['next_update'] = now_as_iso(
                    service.update_config.update_interval_seconds)
                self.services.set(service_name, update_data)

            if update_hash:
                self.log.info(
                    f"New update applied for {service_name}. Restarting service."
                )
                self.controller.restart(service_name=service_name)

        except BaseException:
            self.log.exception(
                "An error occurred while running an update for: " +
                service_name)

    def do_build_update(self):
        """Update a service by building a new container to run."""
        raise NotImplementedError()

    def do_file_update(self, service, previous_hash, previous_update):
        """Update a service by running a container to get new files."""
        temp_directory = tempfile.mkdtemp(dir=self.temporary_directory)
        chmod(temp_directory, 0o777)
        input_directory = os.path.join(temp_directory, 'input_directory')
        output_directory = os.path.join(temp_directory, 'output_directory')
        service_dir = os.path.join(FILE_UPDATE_DIRECTORY, service.name)
        image_variables = defaultdict(str)
        image_variables.update(self.config.services.image_variables)

        try:
            # Use chmod directly to avoid effects of umask
            os.makedirs(input_directory)
            chmod(input_directory, 0o755)
            os.makedirs(output_directory)
            chmod(output_directory, 0o777)

            username = self.ensure_service_account()

            with temporary_api_key(self.datastore, username) as api_key:

                # Write out the parameters we want to pass to the update container
                with open(os.path.join(input_directory, 'config.yaml'),
                          'w') as fh:
                    yaml.safe_dump(
                        {
                            'previous_update': previous_update,
                            'previous_hash': previous_hash,
                            'sources': [x.as_primitives() for x in service.update_config.sources],
                            'api_user': username,
                            'api_key': api_key,
                            'ui_server': UI_SERVER
                        }, fh)

                # Run the update container
                run_options = service.update_config.run_options
                run_options.image = string.Template(
                    run_options.image).safe_substitute(image_variables)
                self.controller.launch(
                    name=service.name,
                    docker_config=run_options,
                    mounts=[
                        {
                            'volume': FILE_UPDATE_VOLUME,
                            'source_path': os.path.relpath(temp_directory,
                                                           start=FILE_UPDATE_DIRECTORY),
                            'dest_path': '/mount/'
                        },
                    ],
                    env={
                        'UPDATE_CONFIGURATION_PATH': '/mount/input_directory/config.yaml',
                        'UPDATE_OUTPUT_PATH': '/mount/output_directory/'
                    },
                    network=f'service-net-{service.name}',
                    blocking=True,
                )

                # Read out the results from the output container
                results_meta_file = os.path.join(output_directory,
                                                 'response.yaml')

                if not os.path.exists(results_meta_file) or not os.path.isfile(
                        results_meta_file):
                    self.log.warning(
                        f"Update produced no output for {service.name}")
                    return None

                with open(results_meta_file) as rf:
                    results_meta = yaml.safe_load(rf)
                update_hash = results_meta.get('hash', None)

                # Erase the results meta file
                os.unlink(results_meta_file)

                # Get a timestamp for now, and switch it to basic format representation of time
                # Still valid iso 8601, and : is sometimes a restricted character
                timestamp = now_as_iso().replace(":", "")
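                # For example (illustrative, the exact now_as_iso format may differ):
                #   '2024-01-19T15:30:00.123456Z' -> '2024-01-19T153000.123456Z'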

                # FILE_UPDATE_DIRECTORY/{service_name} is the directory mounted to the service,
                # the service sees multiple directories in that directory, each with a timestamp
                destination_dir = os.path.join(service_dir,
                                               service.name + '_' + timestamp)
                shutil.move(output_directory, destination_dir)

                # Remove older update files, due to the naming scheme, older ones will sort first lexically
                existing_folders = []
                for folder_name in os.listdir(service_dir):
                    folder_path = os.path.join(service_dir, folder_name)
                    if os.path.isdir(folder_path) and folder_name.startswith(
                            service.name):
                        existing_folders.append(folder_name)
                existing_folders.sort()

                self.log.info(
                    f'There are {len(existing_folders)} update folders for {service.name} in cache.'
                )
                if len(existing_folders) > UPDATE_FOLDER_LIMIT:
                    extra_count = len(existing_folders) - UPDATE_FOLDER_LIMIT
                    self.log.info(
                        f'We will only keep {UPDATE_FOLDER_LIMIT} updates, deleting {extra_count}.'
                    )
                    for extra_folder in existing_folders[:extra_count]:
                        # noinspection PyBroadException
                        try:
                            shutil.rmtree(
                                os.path.join(service_dir, extra_folder))
                        except Exception:
                            self.log.exception(
                                'Failed to delete update folder')

                return update_hash
        finally:
            # If the working directory is still there for any reason erase it
            shutil.rmtree(temp_directory, ignore_errors=True)

    def ensure_service_account(self):
        """Check that the update service account exists, if it doesn't, create it."""
        uname = 'update_service_account'

        if self.datastore.user.get_if_exists(uname):
            return uname

        user_data = User({
            "agrees_with_tos": "NOW",
            "classification": "RESTRICTED",
            "name": "Update Account",
            "password": get_password_hash(''.join(random.choices(string.ascii_letters, k=20))),
            "uname": uname,
            "type": ["signature_importer"]
        })
        self.datastore.user.save(uname, user_data)
        self.datastore.user_settings.save(uname, UserSettings())
        return uname
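
A sketch of how the updater loop could be started (hedged: the real deployment uses the package entrypoint, and FILE_UPDATE_DIRECTORY plus the other orchestration settings described in __init__ must already be in place):

# Illustrative only.
updater = ServiceUpdater()
updater.try_run()   # registers sync_services/update_services/container_versions/container_updates
                    # with the scheduler, then loops while the component is running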
Example 5
class DispatchClient:
    def __init__(self, datastore=None, redis=None, redis_persist=None, logger=None):
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        self.redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger("assemblyline.dispatching.client")
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.submission_assignments = ExpiringHash(DISPATCH_TASK_HASH, host=self.redis_persist)
        self.running_tasks = Hash(DISPATCH_RUNNING_TASK_HASH, host=self.redis)
        self.service_data = cast(Dict[str, Service], CachedObject(self._get_services))
        self.dispatcher_data = []
        self.dispatcher_data_age = 0.0
        self.dead_dispatchers = []

    @weak_lru(maxsize=128)
    def _get_queue_from_cache(self, name):
        return NamedQueue(name, host=self.redis, ttl=QUEUE_EXPIRY)

    def _get_services(self):
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def is_dispatcher(self, dispatcher_id) -> bool:
        if dispatcher_id in self.dead_dispatchers:
            return False
        if time.time() - self.dispatcher_data_age > 120 or dispatcher_id not in self.dispatcher_data:
            self.dispatcher_data = Dispatcher.all_instances(self.redis_persist)
            self.dispatcher_data_age = time.time()
        if dispatcher_id in self.dispatcher_data:
            return True
        else:
            self.dead_dispatchers.append(dispatcher_id)
            return False

    def dispatch_bundle(self, submission: Submission, results: Dict[str, Result],
                        file_infos: Dict[str, File], file_tree, errors: Dict[str, Error], completed_queue: str = None):
        """Insert a bundle into the dispatching system and continue scanning of its files

        Prerequisites:
            - Submission, results, file_infos and errors should already be saved in the datastore
            - Files should already be in the filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            results=results,
            file_infos=file_infos,
            file_tree=file_tree,
            errors=errors,
            completed_queue=completed_queue,
        ))

    def dispatch_submission(self, submission: Submission, completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(dict(
            submission=submission.as_primitives(),
            completed_queue=completed_queue,
        ))

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of files each
        of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and number of files
                 remaining per service, e.g. {"SERVICE_NAME": 1, ... }
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="ResponseQueue")
            queue = NamedQueue(queue_name, host=self.redis, ttl=30)
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, ttl=QUEUE_EXPIRY, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': LIST_OUTSTANDING,
                'payload_data': ListOutstanding({
                    'response_queue': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue.pop(timeout=30)
        return {}

    @elasticapm.capture_span(span_type='dispatch_client')
    def request_work(self, worker_id, service_name, service_version,
                     timeout: float = 60, blocking=True, low_priority=False) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param worker_id: Unique identifier for the worker requesting the work.
        :param service_name: Which service needs work.
        :param service_version: Version of the service requesting the work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately.
        :param low_priority: If true, pull work from the low priority end of the queue.
        :return: The task found, or None if no work was available within the timeout.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            work = self._request_work(worker_id, service_name, service_version,
                                      blocking=blocking, timeout=remaining, low_priority=low_priority)
            if work or not blocking:
                return work
            remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version,
                      timeout, blocking, low_priority=False) -> Optional[ServiceTask]:
        # For when request_work loops back to retry after a bad task dequeue
        if int(timeout) <= 0:
            self.log.info(f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout), low_priority=low_priority)
        else:
            if low_priority:
                result = work_queue.unpush(1)
            else:
                result = work_queue.pop(1)
            if result:
                result = result[0]

        if not result:
            self.log.info(f"{service_name}:{worker_id} no task returned: [empty message]")
            return None
        task = ServiceTask(result)
        task.metadata['worker__'] = worker_id
        dispatcher = task.metadata['dispatcher__']

        if not self.is_dispatcher(dispatcher):
            self.log.info(f"{service_name}:{worker_id} no task returned: [task from dead dispatcher]")
            return None

        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found")
            start_queue = self._get_queue_from_cache(DISPATCH_START_EVENTS + dispatcher)
            start_queue.push((task.sid, task.fileinfo.sha256, service_name, worker_id))
            return task
        return None

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_finished(self, sid: str, result_key: str, result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(sid=sid, service_name=result.response.service_name, sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing successful results.")
            return
        task = ServiceTask(task)

        # Save or freshen the result. The CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early.
        if result.is_empty():
            # Empty Result will not be archived therefore result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key, {"expiry_ts": result.archive_ts})
        else:
            while True:
                old, version = self.ds.result.get_if_exists(
                    result_key, archive_access=self.config.datastore.ilm.update_archive, version=True)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        result.expiry_ts = None
                try:
                    self.ds.result.save(result_key, result, version=version)
                    break
                except VersionConflictException as vce:
                    self.log.info(f"Retrying to save results due to version conflict: {str(vce)}")

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w, host=self.redis).push(msg)

        # Save the tags
        tags = []
        for section in result.result.sections:
            tags.extend(tag_dict_to_list(flatten(section.tags.as_primitives())))

        # Pull out file names if we have them
        file_names = {}
        for extracted_data in result.response.extracted:
            if extracted_data.name:
                file_names[extracted_data.sha256] = extracted_data.name

        # Forward the result summary to the dispatcher that owns this task
        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        ex_ts = result.expiry_ts.strftime(DATEFORMAT) if result.expiry_ts else result.archive_ts.strftime(DATEFORMAT)
        result_queue.push({
            # 'service_task': task.as_primitives(),
            # 'result': result.as_primitives(),
            'sid': task.sid,
            'sha256': result.sha256,
            'service_name': task.service_name,
            'service_version': result.response.service_version,
            'service_tool_version': result.response.service_tool_version,
            'archive_ts': result.archive_ts.strftime(DATEFORMAT),
            'expiry_ts': ex_ts,
            'result_summary': {
                'key': result_key,
                'drop': result.drop_file,
                'score': result.result.score,
                'children': [r.sha256 for r in result.response.extracted],
            },
            'tags': tags,
            'extracted_names': file_names,
            'temporary_data': temporary_data
        })

    @elasticapm.capture_span(span_type='dispatch_client')
    def service_failed(self, sid: str, error_key: str, error: Error):
        task_key = ServiceTask.make_key(sid=sid, service_name=error.response.service_name, sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                             f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error.")
        if error.response.status == "FAIL_NONRECOVERABLE":
            # This is a NON_RECOVERABLE error, error will be saved and transmitted to the user
            self.errors.save(error_key, error)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w, host=self.redis).push(msg)

        dispatcher = task.metadata['dispatcher__']
        result_queue = self._get_queue_from_cache(DISPATCH_RESULT_QUEUE + dispatcher)
        result_queue.push({
            'sid': task.sid,
            'service_task': task.as_primitives(),
            'error': error.as_primitives(),
            'error_key': error_key
        })

    def setup_watch_queue(self, sid: str) -> Optional[str]:
        """
        This function takes a submission ID as a parameter and creates a unique queue where all service
        result keys for that given submission will be returned to as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        dispatcher_id = self.submission_assignments.get(sid)
        if dispatcher_id:
            queue_name = reply_queue_name(prefix="D", suffix="WQ")
            command_queue = NamedQueue(DISPATCH_COMMAND_QUEUE+dispatcher_id, host=self.redis)
            command_queue.push(DispatcherCommandMessage({
                'kind': CREATE_WATCH,
                'payload_data': CreateWatch({
                    'queue_name': queue_name,
                    'submission': sid
                })
            }).as_primitives())
            return queue_name

    def _get_watcher_list(self, sid):
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)


class Ingester(ThreadedCoreBase):
    def __init__(self,
                 datastore=None,
                 logger=None,
                 classification=None,
                 redis=None,
                 persistent_redis=None,
                 metrics_name='ingester',
                 config=None):
        super().__init__('assemblyline.ingester',
                         logger,
                         redis=redis,
                         redis_persist=persistent_redis,
                         datastore=datastore,
                         config=config)

        # Cache the user groups
        self.cache_lock = threading.RLock()
        self._user_groups = {}
        self._user_groups_reset = time.time() // HOUR_IN_SECONDS
        self.cache = {}
        self.notification_queues = {}
        self.whitelisted = {}
        self.whitelisted_lock = threading.RLock()

        # Module path parameters are fixed at start time. Changing these involves a restart
        self.is_low_priority = load_module_by_path(
            self.config.core.ingester.is_low_priority)
        self.get_whitelist_verdict = load_module_by_path(
            self.config.core.ingester.get_whitelist_verdict)
        self.whitelist = load_module_by_path(
            self.config.core.ingester.whitelist)

        # Constants are loaded from a non-constant path, so this has to be done at init rather than at load time
        constants = forge.get_constants(self.config)
        self.priority_value: dict[str, int] = constants.PRIORITIES
        self.priority_range: dict[str, Tuple[int, int]] = constants.PRIORITY_RANGES
        self.threshold_value: dict[str, int] = constants.PRIORITY_THRESHOLDS

        # Classification engine
        self.ce = classification or forge.get_classification()

        # Metrics gathering factory
        self.counter = MetricsFactory(metrics_type='ingester',
                                      schema=Metrics,
                                      redis=self.redis,
                                      config=self.config,
                                      name=metrics_name)

        # State. The submissions in progress are stored in Redis in order to
        # persist this state and recover in case we crash.
        self.scanning = Hash('m-scanning-table', self.redis_persist)

        # Input. The dispatcher creates a record when any submission completes.
        self.complete_queue = NamedQueue(COMPLETE_QUEUE_NAME, self.redis)

        # Input. An external process places submission requests on this queue.
        self.ingest_queue = NamedQueue(INGEST_QUEUE_NAME, self.redis_persist)

        # Output. Duplicate our input traffic into this queue so it may be cloned by other systems
        self.traffic_queue = CommsQueue('submissions', self.redis)

        # Internal. Unique requests are placed in and processed from this queue.
        self.unique_queue = PriorityQueue('m-unique', self.redis_persist)

        # Internal, delay queue for retrying
        self.retry_queue = PriorityQueue('m-retry', self.redis_persist)

        # Internal, timeout watch queue
        self.timeout_queue: PriorityQueue[str] = PriorityQueue('m-timeout', self.redis)

        # Internal, queue for processing duplicates
        #   When a duplicate file is detected (same cache key => same file, and same
        #   submission parameters) the file won't be ingested normally; instead a reference
        #   will be written to a duplicate queue. Whenever a file is finished, the completed
        #   method finalizes not only the original ingestion but also all entries in the duplicate
        #   queue. This has the effect that all concurrent ingestions of the same file
        #   are 'merged' into a single submission to the system.
        self.duplicate_queue = MultiQueue(self.redis_persist)

        # Output. Submissions that should have alerts generated
        self.alert_queue = NamedQueue(ALERT_QUEUE_NAME, self.redis_persist)

        # Utility object to help submit tasks to dispatching
        self.submit_client = SubmissionClient(datastore=self.datastore,
                                              redis=self.redis)

        if self.config.core.metrics.apm_server.server_url is not None:
            self.log.info(
                f"Exporting application metrics to: {self.config.core.metrics.apm_server.server_url}"
            )
            elasticapm.instrument()
            self.apm_client = elasticapm.Client(
                server_url=self.config.core.metrics.apm_server.server_url,
                service_name="ingester")
        else:
            self.apm_client = None

    def try_run(self):
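        """Start and maintain the retry, timeout, complete, ingest and submit worker threads."""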
        threads_to_maintain = {
            'Retries': self.handle_retries,
            'Timeouts': self.handle_timeouts
        }
        threads_to_maintain.update({
            f'Complete_{n}': self.handle_complete
            for n in range(COMPLETE_THREADS)
        })
        threads_to_maintain.update(
            {f'Ingest_{n}': self.handle_ingest
             for n in range(INGEST_THREADS)})
        threads_to_maintain.update(
            {f'Submit_{n}': self.handle_submit
             for n in range(SUBMIT_THREADS)})
        self.maintain_threads(threads_to_maintain)

    def handle_ingest(self):
        cpu_mark = time.process_time()
        time_mark = time.time()

        # Move from ingest to unique and waiting queues.
        # While there are entries in the ingest queue we consume chunk_size
        # entries at a time and move unique entries to uniqueq / queued and
        # duplicates to their own queues / waiting.
        while self.running:
            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

            message = self.ingest_queue.pop(timeout=1)

            cpu_mark = time.process_time()
            time_mark = time.time()

            if not message:
                continue

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            try:
                if 'submission' in message:
                    # A retried task
                    task = IngestTask(message)
                else:
                    # A new submission
                    sub = MessageSubmission(message)
                    task = IngestTask(dict(
                        submission=sub,
                        ingest_id=sub.sid,
                    ))
                    task.submission.sid = None  # Reset to new random uuid
                    # Write all input to the traffic queue
                    self.traffic_queue.publish(
                        SubmissionMessage({
                            'msg': sub,
                            'msg_type': 'SubmissionIngested',
                            'sender': 'ingester',
                        }).as_primitives())

            except (ValueError, TypeError) as error:
                self.counter.increment('error')
                self.log.exception(
                    f"Dropped ingest submission {message} because {str(error)}"
                )

                # End of ingest message (value_error)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_input',
                                                    'value_error')
                continue

            self.ingest(task)

            # End of ingest message (success)
            if self.apm_client:
                self.apm_client.end_transaction('ingest_input', 'success')

    def handle_submit(self):
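        """Pull tasks off the unique queue and submit them to the dispatcher, handling drops, duplicates, cache hits and submit failures."""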
        time_mark, cpu_mark = time.time(), time.process_time()

        while self.running:
            # noinspection PyBroadException
            try:
                self.counter.increment_execution_time(
                    'cpu_seconds',
                    time.process_time() - cpu_mark)
                self.counter.increment_execution_time('busy_seconds',
                                                      time.time() - time_mark)

                # Check if there is room for more submissions
                length = self.scanning.length()
                if length >= self.config.core.ingester.max_inflight:
                    self.sleep(0.1)
                    time_mark, cpu_mark = time.time(), time.process_time()
                    continue

                raw = self.unique_queue.blocking_pop(timeout=3)
                time_mark, cpu_mark = time.time(), time.process_time()
                if not raw:
                    continue

                # Start of ingest message
                if self.apm_client:
                    self.apm_client.begin_transaction('ingest_msg')

                task = IngestTask(raw)

                # Check if we need to drop a file for capacity reasons, but only if the
                # number of files in flight is already over 80%
                if length >= self.config.core.ingester.max_inflight * 0.8 and self.drop(
                        task):
                    # End of ingest message (dropped)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'dropped')
                    continue

                if self.is_whitelisted(task):
                    # End of ingest message (whitelisted)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'whitelisted')
                    continue

                # Check if this file has been previously processed.
                pprevious, previous, score, scan_key = None, None, None, None
                if not task.submission.params.ignore_cache:
                    pprevious, previous, score, scan_key = self.check(task)
                else:
                    scan_key = self.stamp_filescore_key(task)

                # If it HAS been previously processed, we are dealing with a resubmission
                # finalize will decide what to do, and put the task back in the queue
                # rewritten properly if we are going to run it again
                if previous:
                    if not task.submission.params.services.resubmit and not pprevious:
                        self.log.warning(
                            f"No psid for what looks like a resubmission of "
                            f"{task.submission.files[0].sha256}: {scan_key}")
                    self.finalize(pprevious, previous, score, task)
                    # End of ingest message (finalized)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'finalized')

                    continue

                # We have decided this file is worth processing

                # Add the task to the scanning table, this is atomic across all submit
                # workers, so if it fails, someone beat us to the punch, record the file
                # as a duplicate then.
                if not self.scanning.add(scan_key, task.as_primitives()):
                    self.log.debug('Duplicate %s',
                                   task.submission.files[0].sha256)
                    self.counter.increment('duplicates')
                    self.duplicate_queue.push(_dup_prefix + scan_key,
                                              task.as_primitives())
                    # End of ingest message (duplicate)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'duplicate')

                    continue

                # We have managed to add the task to the scan table, so now we go
                # ahead with the submission process
                try:
                    self.submit(task)
                    # End of ingest message (submitted)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'submitted')

                    continue
                except Exception as _ex:
                    # For some reason (contained in `ex`) we have failed the submission
                    # The rest of this function is error handling/recovery
                    ex = _ex
                    # traceback = _ex.__traceback__

                self.counter.increment('error')

                should_retry = True
                if isinstance(ex, CorruptedFileStoreException):
                    self.log.error(
                        "Submission for file '%s' failed due to corrupted "
                        "filestore: %s" % (task.sha256, str(ex)))
                    should_retry = False
                elif isinstance(ex, DataStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed due to "
                                   "data store error:\n%s" %
                                   (task.sha256, trace))
                elif not isinstance(ex, FileStoreException):
                    trace = exceptions.get_stacktrace_info(ex)
                    self.log.error("Submission for file '%s' failed: %s" %
                                   (task.sha256, trace))

                # Pull the scanning entry back out so the failed task can be retried
                scanning_entry = self.scanning.pop(scan_key)
                if not scanning_entry:
                    self.log.error('No scanning entry for %s', task.sha256)
                    # End of ingest message (no_scan_entry)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'no_scan_entry')

                    continue
                task = IngestTask(scanning_entry)

                if not should_retry:
                    # End of ingest message (cannot_retry)
                    if self.apm_client:
                        self.apm_client.end_transaction(
                            'ingest_submit', 'cannot_retry')

                    continue

                self.retry(task, scan_key, ex)
                # End of ingest message (retry)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit', 'retried')

            except Exception:
                self.log.exception("Unexpected error")
                # End of ingest message (exception)
                if self.apm_client:
                    self.apm_client.end_transaction('ingest_submit',
                                                    'exception')

    def handle_complete(self):
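        """Process completion messages coming back from the dispatcher."""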
        while self.running:
            result = self.complete_queue.pop(timeout=3)
            if not result:
                continue

            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_msg')

            sub = DatabaseSubmission(result)
            self.completed(sub)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(sid=sub.sid)
                self.apm_client.end_transaction('ingest_complete', 'success')

            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

    def handle_retries(self):
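        """Move tasks whose retry delay has elapsed back onto the ingest queue."""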
        tasks = []
        while self.sleep(0 if tasks else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_retries')

            tasks = self.retry_queue.dequeue_range(upper_limit=isotime.now(),
                                                   num=100)

            for task in tasks:
                self.ingest_queue.push(task)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(retries=len(tasks))
                self.apm_client.end_transaction('ingest_retries', 'success')

            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

    def handle_timeouts(self):
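        """Expire in-flight submissions that have exceeded their allotted scan time."""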
        timeouts = []
        while self.sleep(0 if timeouts else 3):
            cpu_mark = time.process_time()
            time_mark = time.time()

            # Start of ingest message
            if self.apm_client:
                self.apm_client.begin_transaction('ingest_timeouts')

            timeouts = self.timeout_queue.dequeue_range(
                upper_limit=isotime.now(), num=100)

            for scan_key in timeouts:
                # noinspection PyBroadException
                try:
                    actual_timeout = False

                    # Remove the entry from the hash of submissions in progress.
                    entry = self.scanning.pop(scan_key)
                    if entry:
                        actual_timeout = True
                        self.log.error("Submission timed out for %s: %s",
                                       scan_key, str(entry))

                    dup = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                                   blocking=False)
                    if dup:
                        actual_timeout = True

                    while dup:
                        self.log.error("Submission timed out for %s: %s",
                                       scan_key, str(dup))
                        dup = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                                       blocking=False)

                    if actual_timeout:
                        self.counter.increment('timed_out')
                except Exception:
                    self.log.exception("Problem timing out %s:", scan_key)

            # End of ingest message (success)
            if self.apm_client:
                elasticapm.label(timeouts=len(timeouts))
                self.apm_client.end_transaction('ingest_timeouts', 'success')

            self.counter.increment_execution_time(
                'cpu_seconds',
                time.process_time() - cpu_mark)
            self.counter.increment_execution_time('busy_seconds',
                                                  time.time() - time_mark)

    def get_groups_from_user(self, username: str) -> List[str]:
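        """Return a user's groups, using a cache that resets at the top of each hour."""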
        # Reset the group cache at the top of each hour
        if time.time() // HOUR_IN_SECONDS > self._user_groups_reset:
            self._user_groups = {}
            self._user_groups_reset = time.time() // HOUR_IN_SECONDS

        # Get the groups for this user if not known
        if username not in self._user_groups:
            user_data = self.datastore.user.get(username)
            if user_data:
                self._user_groups[username] = user_data.groups
            else:
                self._user_groups[username] = []
        return self._user_groups[username]

    def ingest(self, task: IngestTask):
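        """Validate an incoming task, assign its priority and queue it for submission (or finalize a cache hit)."""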
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Task received for processing"
        )
        # Load a snapshot of ingest parameters as of right now.
        max_file_size = self.config.submission.max_file_size
        param = task.params

        self.counter.increment('bytes_ingested', increment_by=task.file_size)
        self.counter.increment('submissions_ingested')

        if any(len(file.sha256) != 64 for file in task.submission.files):
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] Invalid sha256, skipped")
            self.send_notification(task,
                                   failure="Invalid sha256",
                                   logfunc=self.log.warning)
            return

        # Drop oversized metadata values; since we may delete some, iterate over a copy of the keys
        for key in list(task.submission.metadata.keys()):
            value = task.submission.metadata[key]
            meta_size = len(value)
            if meta_size > self.config.submission.max_metadata_length:
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] '
                    f'Removing {key} from metadata because value is too big')
                task.submission.metadata.pop(key)

        if task.file_size > max_file_size and not task.params.ignore_size and not task.params.never_drop:
            task.failure = f"File too large ({task.file_size} > {max_file_size})"
            self._notify_drop(task)
            self.counter.increment('skipped')
            self.log.error(
                f"[{task.ingest_id} :: {task.sha256}] {task.failure}")
            return

        # Set the groups from the user, if they aren't already set
        if not task.params.groups:
            task.params.groups = self.get_groups_from_user(
                task.params.submitter)

        # Check if this file has already been processed (cache hit)
        self.stamp_filescore_key(task)
        pprevious, previous, score = None, None, None
        if not param.ignore_cache:
            pprevious, previous, score, _ = self.check(task, count_miss=False)

        # Assign priority.
        low_priority = self.is_low_priority(task)

        priority = param.priority
        if priority < 0:
            priority = self.priority_value['medium']

            if score is not None:
                priority = self.priority_value['low']
                for level, threshold in self.threshold_value.items():
                    if score >= threshold:
                        priority = self.priority_value[level]
                        break
            elif low_priority:
                priority = self.priority_value['low']

        # Reduce the priority by an order of magnitude for very old files.
        current_time = now()
        if priority and self.expired(
                current_time - task.submission.time.timestamp(), 0):
            priority = (priority // 10) or 1

        param.priority = priority

        # Do this after priority has been assigned.
        # (So we don't end up dropping the resubmission).
        if previous:
            self.counter.increment('duplicates')
            self.finalize(pprevious, previous, score, task)

            # On cache hits of any kind we want to send out a completed message
            self.traffic_queue.publish(
                SubmissionMessage({
                    'msg': task.submission,
                    'msg_type': 'SubmissionCompleted',
                    'sender': 'ingester',
                }).as_primitives())
            return

        if self.drop(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Dropped")
            return

        if self.is_whitelisted(task):
            self.log.info(f"[{task.ingest_id} :: {task.sha256}] Whitelisted")
            return

        self.unique_queue.push(priority, task.as_primitives())

    def check(
        self,
        task: IngestTask,
        count_miss=True
    ) -> Tuple[Optional[str], Optional[str], Optional[float], str]:
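        """Check the local cache then the filescore collection for a previous result.

        Returns a (psid, sid, score, scan_key) tuple where the first three values are None on a miss.
        """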
        key = self.stamp_filescore_key(task)

        with self.cache_lock:
            result = self.cache.get(key, None)

        if result:
            self.counter.increment('cache_hit_local')
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Local cache hit')
        else:
            result = self.datastore.filescore.get_if_exists(key)
            if result:
                self.counter.increment('cache_hit')
                self.log.info(
                    f'[{task.ingest_id} :: {task.sha256}] Remote cache hit')
            else:
                if count_miss:
                    self.counter.increment('cache_miss')
                return None, None, None, key

            with self.cache_lock:
                self.cache[key] = result

        current_time = now()
        age = current_time - result.time
        errors = result.errors

        if self.expired(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache has expired"
            )
            self.counter.increment('cache_expired')
            self.cache.pop(key, None)
            self.datastore.filescore.delete(key)
            return None, None, None, key
        elif self.stale(age, errors):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Cache hit dropped, cache is stale"
            )
            self.counter.increment('cache_stale')
            return None, None, result.score, key

        return result.psid, result.sid, result.score, key

    def stop(self):
        super().stop()
        if self.apm_client:
            elasticapm.uninstrument()
        self.submit_client.stop()

    def stale(self, delta: float, errors: int):
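        """Return True when a cached result is older than the configured stale threshold."""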
        if errors:
            return delta >= self.config.core.ingester.incomplete_stale_after_seconds
        else:
            return delta >= self.config.core.ingester.stale_after_seconds

    @staticmethod
    def stamp_filescore_key(task: IngestTask, sha256: str = None) -> str:
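        """Ensure the submission carries a scan (filescore) key and return it."""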
        if not sha256:
            sha256 = task.submission.files[0].sha256

        key = task.submission.scan_key

        if not key:
            key = task.params.create_filescore_key(sha256)
            task.submission.scan_key = key

        return key

    def completed(self, sub: DatabaseSubmission):
        """Invoked when notified that a submission has completed."""
        # There is only one file in the submissions we have made
        sha256 = sub.files[0].sha256
        scan_key = sub.scan_key
        if not scan_key:
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission missing scan key")
            scan_key = sub.params.create_filescore_key(sha256)
        raw = self.scanning.pop(scan_key)

        psid = sub.params.psid
        score = sub.max_score
        sid = sub.sid

        if not raw:
            # Some other worker has already popped the scanning queue?
            self.log.warning(
                f"[{sub.metadata.get('ingest_id', 'unknown')} :: {sha256}] "
                f"Submission completed twice")
            return scan_key

        task = IngestTask(raw)
        task.submission.sid = sid

        errors = sub.error_count
        file_count = sub.file_count
        self.counter.increment('submissions_completed')
        self.counter.increment('files_completed', increment_by=file_count)
        self.counter.increment('bytes_completed', increment_by=task.file_size)

        with self.cache_lock:
            fs = self.cache[scan_key] = FileScore({
                'expiry_ts': now(self.config.core.ingester.cache_dtl * 24 * 60 * 60),
                'errors': errors,
                'psid': psid,
                'score': score,
                'sid': sid,
                'time': now(),
            })
        self.datastore.filescore.save(scan_key, fs)

        self.finalize(psid, sid, score, task)

        def exhaust() -> Iterable[IngestTask]:
            while True:
                res = self.duplicate_queue.pop(_dup_prefix + scan_key,
                                               blocking=False)
                if res is None:
                    break
                res = IngestTask(res)
                res.submission.sid = sid
                yield res

        # You may be tempted to remove the assignment to dups and use the
        # value directly in the for loop below. That would be a mistake.
        # The function finalize may push on the duplicate queue which we
        # are pulling off and so condensing those two lines creates a
        # potential infinite loop.
        dups = list(exhaust())
        for dup in dups:
            self.finalize(psid, sid, score, dup)

        return scan_key

    def send_notification(self, task: IngestTask, failure=None, logfunc=None):
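        """Log any failure and, if a notification queue is configured and the score threshold is met, push the task to it."""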
        if logfunc is None:
            logfunc = self.log.info

        if failure:
            task.failure = failure

        failure = task.failure
        if failure:
            logfunc("%s: %s", failure, str(task.json()))

        if not task.submission.notification.queue:
            return

        note_queue = _notification_queue_prefix + task.submission.notification.queue
        threshold = task.submission.notification.threshold

        if threshold is not None and task.score is not None and task.score < threshold:
            return

        q = self.notification_queues.get(note_queue, None)
        if not q:
            self.notification_queues[note_queue] = q = NamedQueue(
                note_queue, self.redis_persist)
        q.push(task.as_primitives())

    def expired(self, delta: float, errors) -> bool:
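        """Return True when the given age exceeds the configured expiry threshold."""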
        if errors:
            return delta >= self.config.core.ingester.incomplete_expire_after_seconds
        else:
            return delta >= self.config.core.ingester.expire_after

    def drop(self, task: IngestTask) -> bool:
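        """Decide whether to drop the task based on priority, sampling thresholds and file size; notify if dropped."""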
        priority = task.params.priority
        sample_threshold = self.config.core.ingester.sampling_at

        dropped = False
        if priority <= _min_priority:
            dropped = True
        else:
            for level, rng in self.priority_range.items():
                if rng[0] <= priority <= rng[1] and level in sample_threshold:
                    dropped = must_drop(self.unique_queue.count(*rng),
                                        sample_threshold[level])
                    break

            if not dropped:
                if task.file_size > self.config.submission.max_file_size or task.file_size == 0:
                    dropped = True

        if task.params.never_drop or not dropped:
            return False

        task.failure = 'Skipped'
        self._notify_drop(task)
        self.counter.increment('skipped')
        return True

    def _notify_drop(self, task: IngestTask):
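        """Send the drop notification and save or freshen the file record in the datastore."""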
        self.send_notification(task)

        c12n = task.params.classification
        expiry = now_as_iso(86400)
        sha256 = task.submission.files[0].sha256

        self.datastore.save_or_freshen_file(sha256, {'sha256': sha256},
                                            expiry,
                                            c12n,
                                            redis=self.redis)

    def is_whitelisted(self, task: IngestTask):
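        """Return the whitelist reason for this task, if any, dropping it and caching the verdict on a hit."""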
        reason, hit = self.get_whitelist_verdict(self.whitelist, task)
        hit = {x: dotdump(safe_str(y)) for x, y in hit.items()}
        sha256 = task.submission.files[0].sha256

        if not reason:
            with self.whitelisted_lock:
                reason = self.whitelisted.get(sha256, None)
                if reason:
                    hit = 'cached'

        if reason:
            if hit != 'cached':
                with self.whitelisted_lock:
                    self.whitelisted[sha256] = reason

            task.failure = "Whitelisting due to reason %s (%s)" % (dotdump(
                safe_str(reason)), hit)
            self._notify_drop(task)

            self.counter.increment('whitelisted')

        return reason

    def submit(self, task: IngestTask):
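        """Send the submission to the dispatcher and schedule a scan timeout for it."""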
        self.submit_client.submit(
            submission_obj=task.submission,
            completed_queue=COMPLETE_QUEUE_NAME,
        )

        self.timeout_queue.push(int(now(_max_time)), task.submission.scan_key)
        self.log.info(
            f"[{task.ingest_id} :: {task.sha256}] Submitted to dispatcher for analysis"
        )

    def retry(self, task: IngestTask, scan_key: str, ex):
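        """Requeue the task with a delay unless it has exceeded max retries or expired."""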
        current_time = now()

        retries = task.retries + 1

        if retries > _max_retries:
            trace = ''
            if ex:
                trace = ': ' + get_stacktrace_info(ex)
            self.log.error(
                f'[{task.ingest_id} :: {task.sha256}] Max retries exceeded {trace}'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        elif self.expired(current_time - task.ingest_time.timestamp(), 0):
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] No point retrying expired submission'
            )
            self.duplicate_queue.delete(_dup_prefix + scan_key)
        else:
            self.log.info(
                f'[{task.ingest_id} :: {task.sha256}] Requeuing ({ex or "unknown"})'
            )
            task.retries = retries
            self.retry_queue.push(int(now(_retry_delay)), task.as_primitives())

    def finalize(self, psid: str, sid: str, score: float, task: IngestTask):
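        """Complete a task: send notifications, raise an alert if warranted and resubmit for extended scanning when needed."""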
        self.log.info(f"[{task.ingest_id} :: {task.sha256}] Completed")
        if psid:
            task.params.psid = psid
        task.score = score
        task.submission.sid = sid

        selected = task.params.services.selected
        resubmit_to = task.params.services.resubmit

        resubmit_selected = determine_resubmit_selected(selected, resubmit_to)
        will_resubmit = resubmit_selected and should_resubmit(score)
        if will_resubmit:
            task.extended_scan = 'submitted'
            task.params.psid = None

        if self.is_alert(task, score):
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Notifying alerter "
                f"to {'update' if task.params.psid else 'create'} an alert")
            self.alert_queue.push(task.as_primitives())

        self.send_notification(task)

        if will_resubmit:
            self.log.info(
                f"[{task.ingest_id} :: {task.sha256}] Resubmitted for extended analysis"
            )
            task.params.psid = sid
            task.submission.sid = None
            task.submission.scan_key = None
            task.params.services.resubmit = []
            task.params.services.selected = resubmit_selected

            self.unique_queue.push(task.params.priority, task.as_primitives())

    def is_alert(self, task: IngestTask, score: float) -> bool:
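        """Return True when the task asked for alerts and the score meets the alerter threshold."""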
        if not task.params.generate_alert:
            return False

        if score < self.config.core.alerter.threshold:
            return False

        return True