Example #1
class WatcherClient:
    def __init__(self, redis_persist=None):
        config = forge.get_config()

        self.redis = redis_persist or get_client(
            host=config.core.redis.persistent.host,
            port=config.core.redis.persistent.port,
            private=False,
        )
        # Use the resolved client here; the raw redis_persist argument may be None
        self.hash = ExpiringHash(name=WATCHER_HASH,
                                 ttl=MAX_TIMEOUT,
                                 host=self.redis)
        self.queue = UniquePriorityQueue(WATCHER_QUEUE, self.redis)

    def touch(self, timeout: int, key: str, queue: str, message: dict):
        if timeout >= MAX_TIMEOUT:
            raise ValueError(f"Can't set watcher timeouts over {MAX_TIMEOUT}")
        self.hash.set(key, {
            'action': WatcherAction.Message,
            'queue': queue,
            'message': message
        })
        seconds, _ = retry_call(self.redis.time)
        self.queue.push(int(seconds + timeout), key)

    def touch_task(self, timeout: int, key: str, worker: str, task_key: str):
        if timeout >= MAX_TIMEOUT:
            raise ValueError(f"Can't set watcher timeouts over {MAX_TIMEOUT}")
        self.hash.set(
            key, {
                'action': WatcherAction.TimeoutTask,
                'worker': worker,
                'task_key': task_key
            })
        seconds, _ = retry_call(self.redis.time)
        self.queue.push(int(seconds + timeout), key)

    def clear(self, key: str):
        self.queue.remove(key)
        self.hash.pop(key)
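
A minimal usage sketch for the client above (an illustration, not part of the original source): it assumes a reachable persistent Redis instance as resolved through forge, and the key, queue name, and message are placeholders.

# Hypothetical usage of the WatcherClient above
watcher = WatcherClient()

# Schedule a timeout action: if 'task-1234' is not cleared within 60 seconds,
# the watcher will push the message onto the named queue
watcher.touch(timeout=60,
              key='task-1234',
              queue='notify-queue',
              message={'status': 'timeout'})

# The work finished in time, so cancel the pending timeout action
watcher.clear('task-1234')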
Example #2
class TaskingClient:
    """A helper class to simplify tasking for privileged services and service-server.

    This tool helps take care of interactions between the filestore,
    datastore, dispatcher, and any sources of files to be processed.
    """
    def __init__(self,
                 datastore: AssemblylineDatastore = None,
                 filestore: FileStore = None,
                 config=None,
                 redis=None,
                 redis_persist=None,
                 identify=None):
        self.log = logging.getLogger('assemblyline.tasking_client')
        self.config = config or forge.CachedObject(forge.get_config)
        self.datastore = datastore or forge.get_datastore(self.config)
        self.dispatch_client = DispatchClient(self.datastore,
                                              redis=redis,
                                              redis_persist=redis_persist)
        self.event_sender = EventSender('changes.services', redis)
        self.filestore = filestore or forge.get_filestore(self.config)
        self.heuristic_handler = HeuristicHandler(self.datastore)
        self.heuristics = {
            h.heur_id: h
            for h in self.datastore.list_all_heuristics()
        }
        self.status_table = ExpiringHash(SERVICE_STATE_HASH,
                                         ttl=60 * 30,
                                         host=redis)
        self.tag_safelister = forge.CachedObject(forge.get_tag_safelister,
                                                 kwargs=dict(
                                                     log=self.log,
                                                     config=config,
                                                     datastore=self.datastore),
                                                 refresh=300)
        # Only stop the identify helper on cleanup if we created it ourselves
        self.cleanup = not identify
        self.identify = identify or forge.get_identify(
            config=self.config, datastore=self.datastore, use_cache=True)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.stop()

    def stop(self):
        if self.cleanup:
            self.identify.stop()

    @elasticapm.capture_span(span_type='tasking_client')
    def upload_file(self,
                    file_path,
                    classification,
                    ttl,
                    is_section_image,
                    expected_sha256=None):
        # Identify the file info of the uploaded file
        file_info = self.identify.fileinfo(file_path)

        # Validate SHA256 of the uploaded file
        if expected_sha256 is None or expected_sha256 == file_info['sha256']:
            file_info['archive_ts'] = now_as_iso(
                self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            file_info['classification'] = classification
            if ttl:
                file_info['expiry_ts'] = now_as_iso(ttl * 24 * 60 * 60)
            else:
                file_info['expiry_ts'] = None

            # Update the datastore with the uploaded file
            self.datastore.save_or_freshen_file(
                file_info['sha256'],
                file_info,
                file_info['expiry_ts'],
                file_info['classification'],
                is_section_image=is_section_image)

            # Upload file to the filestore (upload already checks if the file exists)
            self.filestore.upload(file_path, file_info['sha256'])
        else:
            raise TaskingClientException(
                "Uploaded file does not match expected file hash. "
                f"[{file_info['sha256']} != {expected_sha256}]")

    # Service
    @elasticapm.capture_span(span_type='tasking_client')
    def register_service(self, service_data, log_prefix=""):
        keep_alive = True

        try:
            # Get heuristics list
            heuristics = service_data.pop('heuristics', None)

            # Patch update_channel, registry_type before Service registration object creation
            service_data['update_channel'] = service_data.get(
                'update_channel',
                self.config.services.preferred_update_channel)
            service_data['docker_config']['registry_type'] = service_data['docker_config'] \
                .get('registry_type', self.config.services.preferred_registry_type)
            service_data['privileged'] = service_data.get(
                'privileged', self.config.services.prefer_service_privileged)
            for dep in service_data.get('dependencies', {}).values():
                dep['container']['registry_type'] = dep.get(
                    'registry_type',
                    self.config.services.preferred_registry_type)

            # Pop unused registration service_data
            for x in ['file_required', 'tool_version']:
                service_data.pop(x, None)

            # Create Service registration object
            service = Service(service_data)

            # Fix service version, we don't need to see the stable label
            service.version = service.version.replace('stable', '')

            # Save service if it doesn't already exist
            if not self.datastore.service.exists(
                    f'{service.name}_{service.version}'):
                self.datastore.service.save(
                    f'{service.name}_{service.version}', service)
                self.datastore.service.commit()
                self.log.info(f"{log_prefix}{service.name} registered")
                keep_alive = False

            # Save service delta if it doesn't already exist
            if not self.datastore.service_delta.exists(service.name):
                self.datastore.service_delta.save(service.name,
                                                  {'version': service.version})
                self.datastore.service_delta.commit()
                self.log.info(f"{log_prefix}{service.name} "
                              f"version ({service.version}) registered")

            new_heuristics = []
            if heuristics:
                plan = self.datastore.heuristic.get_bulk_plan()
                for index, heuristic in enumerate(heuristics):
                    heuristic_id = f'#{index}'  # Set the heuristic ID to its position in the list for logging purposes
                    try:
                        # Append service name to heuristic ID
                        heuristic['heur_id'] = f"{service.name.upper()}.{heuristic['heur_id']}"

                        # The attack_id field is now a list; wrap a single string into one
                        attack_id = heuristic.get('attack_id', None)
                        if isinstance(attack_id, str):
                            heuristic['attack_id'] = [attack_id]

                        heuristic = Heuristic(heuristic)
                        heuristic_id = heuristic.heur_id
                        plan.add_upsert_operation(heuristic_id, heuristic)
                    except Exception as e:
                        msg = f"{service.name} has an invalid heuristic ({heuristic_id}): {str(e)}"
                        self.log.exception(f"{log_prefix}{msg}")
                        raise ValueError(msg)

                for item in self.datastore.heuristic.bulk(plan)['items']:
                    if item['update']['result'] != "noop":
                        new_heuristics.append(item['update']['_id'])
                        self.log.info(
                            f"{log_prefix}{service.name} "
                            f"heuristic {item['update']['_id']}: {item['update']['result'].upper()}"
                        )

                self.datastore.heuristic.commit()

            service_config = self.datastore.get_service_with_delta(
                service.name, as_obj=False)

            # Notify components watching for service config changes
            self.event_sender.send(service.name, {
                'operation': Operation.Added,
                'name': service.name
            })

        except ValueError:  # Catch errors when building Service or Heuristic model(s)
            raise

        return dict(keep_alive=keep_alive,
                    new_heuristics=new_heuristics,
                    service_config=service_config or dict())

    # Task
    @elasticapm.capture_span(span_type='tasking_client')
    def get_task(self,
                 client_id,
                 service_name,
                 service_version,
                 service_tool_version,
                 metric_factory,
                 status_expiry=None,
                 timeout=30):
        if status_expiry is None:
            status_expiry = time.time() + timeout

        cache_found = False

        try:
            service_data = self.dispatch_client.service_data[service_name]
        except KeyError:
            raise ServiceMissingException(
                "The service you are requesting tasks for does not exist, try again later",
                404)

        # Set the service status to Idle since we will be waiting for a task
        self.status_table.set(
            client_id, (service_name, ServiceStatus.Idle, status_expiry))

        # Getting a new task
        task = self.dispatch_client.request_work(client_id,
                                                 service_name,
                                                 service_version,
                                                 timeout=timeout)

        if not task:
            # We've reached the timeout and no task was found in the service queue
            return None, False

        # We've got a task to process, consider us busy
        self.status_table.set(client_id, (service_name, ServiceStatus.Running,
                                          time.time() + service_data.timeout))
        metric_factory.increment('execute')

        result_key = Result.help_build_key(
            sha256=task.fileinfo.sha256,
            service_name=service_name,
            service_version=service_version,
            service_tool_version=service_tool_version,
            is_empty=False,
            task=task)

        # If we are allowed, try to see if the result has been cached
        if not task.ignore_cache and not service_data.disable_cache:
            # Checking for previous results for this key
            result = self.datastore.result.get_if_exists(result_key)
            if result:
                metric_factory.increment('cache_hit')
                if result.result.score:
                    metric_factory.increment('scored')
                else:
                    metric_factory.increment('not_scored')

                result.archive_ts = now_as_iso(
                    self.config.datastore.ilm.days_until_archive * 24 * 60 *
                    60)
                if task.ttl:
                    result.expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)

                self.dispatch_client.service_finished(task.sid, result_key,
                                                      result)
                cache_found = True

            if not cache_found:
                # Checking for previous empty results for this key
                result = self.datastore.emptyresult.get_if_exists(
                    f"{result_key}.e")
                if result:
                    metric_factory.increment('cache_hit')
                    metric_factory.increment('not_scored')
                    result = self.datastore.create_empty_result_from_key(
                        result_key)
                    self.dispatch_client.service_finished(
                        task.sid, f"{result_key}.e", result)
                    cache_found = True

            if not cache_found:
                metric_factory.increment('cache_miss')

        else:
            metric_factory.increment('cache_skipped')

        if not cache_found:
            # No luck with the cache, let's dispatch the task to a client
            return task.as_primitives(), False

        return None, True

    @elasticapm.capture_span(span_type='tasking_client')
    def task_finished(self, service_task, client_id, service_name,
                      metric_factory):
        exec_time = service_task.get('exec_time')

        try:
            task = ServiceTask(service_task['task'])

            if 'result' in service_task:  # Task created a result
                missing_files = self._handle_task_result(
                    exec_time, task, service_task['result'], client_id,
                    service_name, service_task['freshen'], metric_factory)
                if missing_files:
                    return dict(success=False, missing_files=missing_files)
                return dict(success=True)

            elif 'error' in service_task:  # Task created an error
                error = service_task['error']
                self._handle_task_error(exec_time, task, error, client_id,
                                        service_name, metric_factory)
                return dict(success=True)
            else:
                return None

        except ValueError:  # Catch errors when building Task or Result model
            raise

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_result(self, exec_time: int, task: ServiceTask,
                            result: Dict[str, Any], client_id, service_name,
                            freshen: bool, metric_factory):
        def freshen_file(file_info_list, item):
            file_info = file_info_list.get(item['sha256'], None)
            if file_info is None or not self.filestore.exists(item['sha256']):
                return True
            else:
                file_info['archive_ts'] = archive_ts
                file_info['expiry_ts'] = expiry_ts
                file_info['classification'] = item['classification']
                self.datastore.save_or_freshen_file(
                    item['sha256'],
                    file_info,
                    file_info['expiry_ts'],
                    file_info['classification'],
                    is_section_image=item.get('is_section_image', False))
            return False

        archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive *
                                24 * 60 * 60)
        if task.ttl:
            expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)
        else:
            expiry_ts = None

        # Check if all files are in the filestore
        if freshen:
            missing_files = []
            hashes = list({
                f['sha256']
                for f in result['response']['extracted'] +
                result['response']['supplementary']
            })
            file_infos = self.datastore.file.multiget(hashes,
                                                      as_obj=False,
                                                      error_on_missing=False)

            with elasticapm.capture_span(
                    name="handle_task_result.freshen_files",
                    span_type="tasking_client"):
                with concurrent.futures.ThreadPoolExecutor(
                        max_workers=5) as executor:
                    res = {
                        f['sha256']: executor.submit(freshen_file, file_infos,
                                                     f)
                        for f in result['response']['extracted'] +
                        result['response']['supplementary']
                    }
                for k, v in res.items():
                    if v.result():
                        missing_files.append(k)

            if missing_files:
                return missing_files

        # Add scores to the heuristics, if any section set a heuristic
        with elasticapm.capture_span(
                name="handle_task_result.process_heuristics",
                span_type="tasking_client"):
            total_score = 0
            for section in result['result']['sections']:
                zeroize_on_sig_safe = section.pop('zeroize_on_sig_safe', True)
                section['tags'] = flatten(section['tags'])
                if section.get('heuristic'):
                    heur_id = f"{service_name.upper()}.{str(section['heuristic']['heur_id'])}"
                    section['heuristic']['heur_id'] = heur_id
                    try:
                        section['heuristic'], new_tags = \
                            self.heuristic_handler.service_heuristic_to_result_heuristic(
                                section['heuristic'], self.heuristics, zeroize_on_sig_safe)
                        for tag in new_tags:
                            section['tags'].setdefault(tag[0], [])
                            if tag[1] not in section['tags'][tag[0]]:
                                section['tags'][tag[0]].append(tag[1])
                        total_score += section['heuristic']['score']
                    except InvalidHeuristicException:
                        section['heuristic'] = None

        # Update the total score of the result
        result['result']['score'] = total_score

        # Add timestamps for creation, archive and expiry
        result['created'] = now_as_iso()
        result['archive_ts'] = archive_ts
        result['expiry_ts'] = expiry_ts

        # Pop the temporary submission data
        temp_submission_data = result.pop('temp_submission_data', None)
        if temp_submission_data:
            old_submission_data = {
                row.name: row.value
                for row in task.temporary_submission_data
            }
            temp_submission_data = {
                k: v
                for k, v in temp_submission_data.items()
                if k not in old_submission_data or v != old_submission_data[k]
            }
            big_temp_data = {
                k: len(str(v))
                for k, v in temp_submission_data.items()
                if len(str(v)) > self.config.submission.max_temp_data_length
            }
            if big_temp_data:
                big_data_sizes = [f"{k}={v}" for k, v in big_temp_data.items()]
                self.log.warning(
                    f"[{task.sid}] The following temporary submission keys where ignored because they are "
                    "bigger then the maximum data size allowed "
                    f"[{self.config.submission.max_temp_data_length}]: {' | '.join(big_data_sizes)}"
                )
                temp_submission_data = {
                    k: v
                    for k, v in temp_submission_data.items()
                    if k not in big_temp_data
                }

        # Process the tag values
        with elasticapm.capture_span(name="handle_task_result.process_tags",
                                     span_type="tasking_client"):
            for section in result['result']['sections']:
                # Perform tag safelisting
                tags, safelisted_tags = self.tag_safelister.get_validated_tag_map(
                    section['tags'])
                section['tags'] = unflatten(tags)
                section['safelisted_tags'] = safelisted_tags

                section['tags'], dropped = construct_safe(
                    Tagging, section.get('tags', {}))

                # Set section score to zero and lower total score if service is set to zeroize score
                # and all tags were safelisted
                if section.pop('zeroize_on_tag_safe', False) and \
                        section.get('heuristic') and \
                        len(tags) == 0 and \
                        len(safelisted_tags) != 0:
                    result['result']['score'] -= section['heuristic']['score']
                    section['heuristic']['score'] = 0

                if dropped:
                    self.log.warning(
                        f"[{task.sid}] Invalid tag data from {service_name}: {dropped}"
                    )

        result = Result(result)
        result_key = result.build_key(
            service_tool_version=result.response.service_tool_version,
            task=task)
        self.dispatch_client.service_finished(task.sid, result_key, result,
                                              temp_submission_data)

        # Metrics
        if result.result.score > 0:
            metric_factory.increment('scored')
        else:
            metric_factory.increment('not_scored')

        self.log.info(
            f"[{task.sid}] {client_id} - {service_name} "
            f"successfully completed task{f' in {exec_time}ms' if exec_time else ''}"
        )

        self.status_table.set(
            client_id, (service_name, ServiceStatus.Idle, time.time() + 5))

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_error(self, exec_time: int, task: ServiceTask,
                           error: Dict[str, Any], client_id, service_name,
                           metric_factory) -> None:
        self.log.info(
            f"[{task.sid}] {client_id} - {service_name} "
            f"failed to complete task{f' in {exec_time}ms' if exec_time else ''}"
        )

        # Add timestamps for creation, archive and expiry
        error['created'] = now_as_iso()
        error['archive_ts'] = now_as_iso(
            self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
        if task.ttl:
            error['expiry_ts'] = now_as_iso(task.ttl * 24 * 60 * 60)

        error = Error(error)
        error_key = error.build_key(
            service_tool_version=error.response.service_tool_version,
            task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error)

        # Metrics
        if error.response.status == 'FAIL_RECOVERABLE':
            metric_factory.increment('fail_recoverable')
        else:
            metric_factory.increment('fail_nonrecoverable')

        self.status_table.set(
            client_id, (service_name, ServiceStatus.Idle, time.time() + 5))
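
To make the flow above concrete, here is a hedged sketch of a service host loop driving get_task and task_finished. run_service, handle_missing, and metric_factory are placeholders for wiring the original example does not show, and the service_task dict mirrors the keys task_finished reads ('task', 'exec_time', and 'result' or 'error').

# Hypothetical service host loop built on the TaskingClient above
with TaskingClient() as tasking:
    while True:
        # Block for up to 30 seconds waiting for a task for this service;
        # task is None on timeout or when a cached result was dispatched
        task, cache_hit = tasking.get_task(client_id='worker-1',
                                           service_name='Extract',
                                           service_version='4.0.0',
                                           service_tool_version=None,
                                           metric_factory=metric_factory)
        if task is None:
            continue

        # run_service is a placeholder; it must return a dict carrying the
        # original 'task', an 'exec_time', and either a 'result' (plus a
        # 'freshen' flag) or an 'error'
        service_task = run_service(task)
        response = tasking.task_finished(service_task, 'worker-1', 'Extract',
                                         metric_factory)
        if response and not response['success']:
            # Some extracted/supplementary files were missing from the
            # filestore; upload them (e.g. via upload_file) and resubmit
            handle_missing(response['missing_files'])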
Example #3
class DispatchClient:
    def __init__(self,
                 datastore=None,
                 redis=None,
                 redis_persist=None,
                 logger=None):
        self.config = forge.get_config()

        self.redis = redis or get_client(
            host=self.config.core.redis.nonpersistent.host,
            port=self.config.core.redis.nonpersistent.port,
            private=False,
        )

        redis_persist = redis_persist or get_client(
            host=self.config.core.redis.persistent.host,
            port=self.config.core.redis.persistent.port,
            private=False,
        )

        self.timeout_watcher = WatcherClient(redis_persist)

        self.submission_queue = NamedQueue(SUBMISSION_QUEUE, self.redis)
        self.file_queue = NamedQueue(FILE_QUEUE, self.redis)
        self.ds = datastore or forge.get_datastore(self.config)
        self.log = logger or logging.getLogger(
            "assemblyline.dispatching.client")
        # Use the resolved datastore; the datastore argument may be None
        self.results = self.ds.result
        self.errors = self.ds.error
        self.files = self.ds.file
        self.active_submissions = ExpiringHash(DISPATCH_TASK_HASH,
                                               host=redis_persist)
        self.running_tasks = ExpiringHash(DISPATCH_RUNNING_TASK_HASH,
                                          host=self.redis)
        self.service_data = cast(Dict[str, Service],
                                 CachedObject(self._get_services))

    def _get_services(self):
        # noinspection PyUnresolvedReferences
        return {x.name: x for x in self.ds.list_all_services(full=True)}

    def dispatch_submission(self,
                            submission: Submission,
                            completed_queue: str = None):
        """Insert a submission into the dispatching system.

        Note:
            You probably actually want to use the SubmissionTool

        Prerequisites:
            - submission should already be saved in the datastore
            - files should already be in the datastore and filestore
        """
        self.submission_queue.push(
            SubmissionTask(
                dict(
                    submission=submission,
                    completed_queue=completed_queue,
                )).as_primitives())

    def outstanding_services(self, sid) -> Dict[str, int]:
        """
        List outstanding services for a given submission and the number of files each
        of them still has to process.

        :param sid: Submission ID
        :return: Dictionary of services and the number of files
                 remaining per service, e.g. {"SERVICE_NAME": 1, ... }
        """
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        all_service_status = dispatch_hash.all_results()
        all_files = dispatch_hash.all_files()
        self.log.info(
            f"[{sid}] Listing outstanding services: {len(all_files)} files "
            f"and {len(all_service_status)} result entries found")

        output: Dict[str, int] = {}

        for file_hash in all_files:
            # The schedule might not be in the cache if the submission or file was just issued,
            # but it will be as soon as the file passes through the dispatcher
            schedule = dispatch_hash.schedules.get(file_hash)
            status_values = all_service_status.get(file_hash, {})

            # Go through the schedule stage by stage so we can react to drops
            # either we have a result and we don't need to count the file (but might drop it)
            # or we don't have a result, and we need to count that file
            while schedule:
                stage = schedule.pop(0)
                for service_name in stage:
                    status = status_values.get(service_name)
                    if status:
                        if status.drop:
                            schedule.clear()
                    else:
                        output[service_name] = output.get(service_name, 0) + 1

        return output

    def request_work(self,
                     worker_id,
                     service_name,
                     service_version,
                     timeout: float = 60,
                     blocking=True) -> Optional[ServiceTask]:
        """Pull work from the service queue for the service in question.

        :param service_version:
        :param worker_id:
        :param service_name: Which service needs work.
        :param timeout: How many seconds to block before returning if blocking is true.
        :param blocking: Whether to wait for jobs to enter the queue, or if false, return immediately
        :return: The job found, and a boolean value indicating if this is the first time this task has
                 been returned by request_work.
        """
        start = time.time()
        remaining = timeout
        while int(remaining) > 0:
            try:
                return self._request_work(worker_id,
                                          service_name,
                                          service_version,
                                          blocking=blocking,
                                          timeout=remaining)
            except RetryRequestWork:
                remaining = timeout - (time.time() - start)
        return None

    def _request_work(self, worker_id, service_name, service_version, timeout,
                      blocking) -> Optional[ServiceTask]:
        # Retries from request_work may have consumed the entire timeout budget
        if int(timeout) <= 0:
            self.log.info(
                f"{service_name}:{worker_id} no task returned [timeout]")
            return None

        # Get work from the queue
        work_queue = get_service_queue(service_name, self.redis)
        if blocking:
            result = work_queue.blocking_pop(timeout=int(timeout))
        else:
            result = work_queue.pop(1)
            if result:
                result = result[0]

        if not result:
            self.log.info(
                f"{service_name}:{worker_id} no task returned: [empty message]"
            )
            return None
        task = ServiceTask(result)

        # Register ourselves as the worker for this task; add() fails if someone else already has it
        if self.running_tasks.add(task.key(), task.as_primitives()):
            self.log.info(
                f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task found"
            )

            process_table = DispatchHash(task.sid, self.redis)

            abandoned = process_table.dispatch_time(
                file_hash=task.fileinfo.sha256, service=task.service_name) == 0
            finished = process_table.finished(
                file_hash=task.fileinfo.sha256,
                service=task.service_name) is not None

            # A service might be re-dispatched as it finishes. When that is the case it can be marked
            # as both finished and dispatched, so drop the dispatch from the table
            if finished and not abandoned:
                process_table.drop_dispatch(file_hash=task.fileinfo.sha256,
                                            service=task.service_name)

            if abandoned or finished:
                self.log.info(
                    f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} task already complete"
                )
                self.running_tasks.pop(task.key())
                raise RetryRequestWork()

            # Check if this task has reached the retry limit
            attempt_record = ExpiringHash(f'dispatch-hash-attempts-{task.sid}',
                                          host=self.redis)
            total_attempts = attempt_record.increment(task.key())
            self.log.info(
                f"[{task.sid}/{task.fileinfo.sha256}] {service_name}:{worker_id} "
                f"task attempt {total_attempts}/3")
            if total_attempts > 3:
                self.log.warning(
                    f"[{task.sid}/{task.fileinfo.sha256}] "
                    f"{service_name}:{worker_id} marking task failed: TASK PREEMPTED "
                )
                error = Error(
                    dict(
                        archive_ts=now_as_iso(
                            self.config.datastore.ilm.days_until_archive * 24 *
                            60 * 60),
                        created='NOW',
                        expiry_ts=now_as_iso(task.ttl * 24 * 60 *
                                             60) if task.ttl else None,
                        response=dict(
                            message='The number of retries has passed the limit.',
                            service_name=task.service_name,
                            service_version=service_version,
                            status='FAIL_NONRECOVERABLE',
                        ),
                        sha256=task.fileinfo.sha256,
                        type="TASK PRE-EMPTED",
                    ))
                error_key = error.build_key(task=task)
                self.service_failed(task.sid, error_key, error)
                export_metrics_once(service_name,
                                    Metrics,
                                    dict(fail_nonrecoverable=1),
                                    host=worker_id,
                                    counter_type='service')
                raise RetryRequestWork()

            # Get the service information
            service_data = self.service_data[task.service_name]
            self.timeout_watcher.touch_task(timeout=int(service_data.timeout),
                                            key=f'{task.sid}-{task.key()}',
                                            worker=worker_id,
                                            task_key=task.key())
            return task
        raise RetryRequestWork()

    def _dispatching_error(self, task, process_table, error):
        error_key = error.build_key(task=task)
        if process_table.add_error(error_key):
            self.errors.save(error_key, error)
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)

    def service_finished(self,
                         sid: str,
                         result_key: str,
                         result: Result,
                         temporary_data: Optional[Dict[str, Any]] = None):
        """Notifies the dispatcher of service completion, and possible new files to dispatch."""
        # Make sure the dispatcher knows we were working on this task
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=result.response.service_name,
            sha=result.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing successful results."
            )
            return
        task = ServiceTask(task)

        # Check if the service is a candidate for dynamic recursion prevention
        if not task.ignore_dynamic_recursion_prevention:
            service_info = self.service_data.get(result.response.service_name,
                                                 None)
            if service_info and service_info.category == "Dynamic Analysis":
                # TODO: This should be done in lua because it can introduce race conditions in the future,
                #       but in the meantime it will remain this way until we confirm it works as expected
                submission = self.active_submissions.get(sid)
                submission['submission']['params']['services'][
                    'runtime_excluded'].append(result.response.service_name)
                self.active_submissions.set(sid, submission)

        # Save or freshen the result. The CONTENT of the result shouldn't change, but we need to keep the
        # most distant expiry time to prevent pulling it out from under another submission too early
        if result.is_empty():
            # Empty results are not archived, so result.archive_ts drives their deletion
            self.ds.emptyresult.save(result_key,
                                     {"expiry_ts": result.archive_ts})
        else:
            with Lock(f"lock-{result_key}", 5, self.redis):
                old = self.ds.result.get(result_key)
                if old:
                    if old.expiry_ts and result.expiry_ts:
                        result.expiry_ts = max(result.expiry_ts, old.expiry_ts)
                    else:
                        result.expiry_ts = None
                self.ds.result.save(result_key, result)

        # Let the logs know we have received a result for this task
        if result.drop_file:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key} but processing will stop after this service."
            )
        else:
            self.log.debug(
                f"[{sid}/{result.sha256}] {task.service_name} succeeded. "
                f"Result will be stored in {result_key}")

        # Store the result object and mark the service finished in the global table
        process_table = DispatchHash(task.sid, self.redis)
        remaining, duplicate = process_table.finish(
            task.fileinfo.sha256, task.service_name, result_key,
            result.result.score, result.classification, result.drop_file)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')
        if duplicate:
            self.log.warning(
                f"[{sid}/{result.sha256}] {result.response.service_name}'s current task was already "
                f"completed in the global processing table.")
            return

        # Push the result tags into redis
        new_tags = []
        for section in result.result.sections:
            new_tags.extend(tag_dict_to_list(section.tags.as_primitives()))
        if new_tags:
            tag_set = ExpiringSet(get_tag_set_name(
                sid=task.sid, file_hash=task.fileinfo.sha256),
                                  host=self.redis)
            tag_set.add(*new_tags)

        # Update the temporary data table for this file
        temp_data_hash = ExpiringHash(get_temporary_submission_data_name(
            sid=task.sid, file_hash=task.fileinfo.sha256),
                                      host=self.redis)
        for key, value in (temporary_data or {}).items():
            temp_data_hash.set(key, value)

        # Send the extracted files to the dispatcher
        depth_limit = self.config.submission.max_extraction_depth
        new_depth = task.depth + 1
        if new_depth < depth_limit:
            # Prepare the temporary data from the parent to build the temporary data table for
            # these newly extracted files
            parent_data = dict(temp_data_hash.items())

            for extracted_data in result.response.extracted:
                if not process_table.add_file(
                        extracted_data.sha256,
                        task.max_files,
                        parent_hash=task.fileinfo.sha256):
                    if parent_data:
                        child_hash_name = get_temporary_submission_data_name(
                            task.sid, extracted_data.sha256)
                        ExpiringHash(child_hash_name,
                                     host=self.redis).multi_set(parent_data)

                    self._dispatching_error(
                        task, process_table,
                        Error({
                            'archive_ts': result.archive_ts,
                            'expiry_ts': result.expiry_ts,
                            'response': {
                                'message':
                                f"Too many files extracted for submission {task.sid}: "
                                f"{extracted_data.sha256} extracted by "
                                f"{task.service_name} will be dropped",
                                'service_name':
                                task.service_name,
                                'service_tool_version':
                                result.response.service_tool_version,
                                'service_version':
                                result.response.service_version,
                                'status':
                                'FAIL_NONRECOVERABLE'
                            },
                            'sha256': extracted_data.sha256,
                            'type': 'MAX FILES REACHED'
                        }))
                    continue
                file_data = self.files.get(extracted_data.sha256)
                self.file_queue.push(
                    FileTask(
                        dict(sid=task.sid,
                             min_classification=task.min_classification.max(
                                 extracted_data.classification).value,
                             file_info=dict(
                                 magic=file_data.magic,
                                 md5=file_data.md5,
                                 mime=file_data.mime,
                                 sha1=file_data.sha1,
                                 sha256=file_data.sha256,
                                 size=file_data.size,
                                 type=file_data.type,
                             ),
                             depth=new_depth,
                             parent_hash=task.fileinfo.sha256,
                             max_files=task.max_files)).as_primitives())
        else:
            for extracted_data in result.response.extracted:
                self._dispatching_error(
                    task, process_table,
                    Error({
                        'archive_ts': result.archive_ts,
                        'expiry_ts': result.expiry_ts,
                        'response': {
                            'message':
                            f"{task.service_name} has extracted a file "
                            f"{extracted_data.sha256} beyond the depth limits",
                            'service_name':
                            result.response.service_name,
                            'service_tool_version':
                            result.response.service_tool_version,
                            'service_version':
                            result.response.service_version,
                            'status':
                            'FAIL_NONRECOVERABLE'
                        },
                        'sha256': extracted_data.sha256,
                        'type': 'MAX DEPTH REACHED'
                    }))

        # If the global table said that this was the last outstanding service,
        # send a message to the dispatchers.
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification.value,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

        # Send the result key to any watching systems
        msg = {'status': 'OK', 'cache_key': result_key}
        for w in self._get_watcher_list(task.sid).members():
            NamedQueue(w).push(msg)

    def service_failed(self, sid: str, error_key: str, error: Error):
        task_key = ServiceTask.make_key(
            sid=sid,
            service_name=error.response.service_name,
            sha=error.sha256)
        task = self.running_tasks.pop(task_key)
        if not task:
            self.log.warning(
                f"[{sid}/{error.sha256}] {error.response.service_name} could not find the specified "
                f"task in its set of running tasks while processing an error.")
            return
        task = ServiceTask(task)

        self.log.debug(
            f"[{sid}/{error.sha256}] {task.service_name} Failed with {error.response.status} error."
        )
        remaining = -1
        # Mark the attempt to process the file over in the dispatch table
        process_table = DispatchHash(task.sid, self.redis)
        if error.response.status == "FAIL_RECOVERABLE":
            # Because the error is recoverable, we will neither save it nor notify the user
            process_table.fail_recoverable(task.fileinfo.sha256,
                                           task.service_name)
        else:
            # This is a non-recoverable error; it will be saved and transmitted to the user
            self.errors.save(error_key, error)

            remaining, _duplicate = process_table.fail_nonrecoverable(
                task.fileinfo.sha256, task.service_name, error_key)

            # Send the result key to any watching systems
            msg = {'status': 'FAIL', 'cache_key': error_key}
            for w in self._get_watcher_list(task.sid).members():
                NamedQueue(w).push(msg)
        self.timeout_watcher.clear(f'{task.sid}-{task.key()}')

        # Send a message to prompt the re-issue of the task if needed
        if remaining <= 0:
            self.file_queue.push(
                FileTask(
                    dict(sid=task.sid,
                         min_classification=task.min_classification,
                         file_info=task.fileinfo,
                         depth=task.depth,
                         max_files=task.max_files)).as_primitives())

    def setup_watch_queue(self, sid):
        """
        This function takes a submission ID as a parameter and creates a unique queue to which all service
        result keys for that submission will be pushed as soon as they come in.

        If the submission is in the middle of processing, this will also send all currently received keys through
        the specified queue so the client that requests the watch queue is up to date.

        :param sid: Submission ID
        :return: The name of the watch queue that was created
        """
        # Create a unique queue
        queue_name = reply_queue_name(prefix="D", suffix="WQ")
        watch_queue = NamedQueue(queue_name, ttl=30)
        watch_queue.push(
            WatchQueueMessage({
                'status': 'START'
            }).as_primitives())

        # Add the newly created queue to the list of queues for the given submission
        self._get_watcher_list(sid).add(queue_name)

        # Push all current keys to the newly created queue (Queue should have a TTL of about 30 sec to 1 minute)
        # Download the entire status table from redis
        dispatch_hash = DispatchHash(sid, self.redis)
        if dispatch_hash.dispatch_count() == 0 and dispatch_hash.finished_count() == 0:
            # The table is empty; do we have this submission at all?
            submission = self.ds.submission.get(sid)
            if not submission or submission.state == 'completed':
                watch_queue.push(
                    WatchQueueMessage({
                        "status": "STOP"
                    }).as_primitives())
            else:
                # We do have a submission, remind the dispatcher to work on it
                self.submission_queue.push({'sid': sid})

        else:
            all_service_status = dispatch_hash.all_results()
            for status_values in all_service_status.values():
                for status in status_values.values():
                    if status.is_error:
                        watch_queue.push(
                            WatchQueueMessage({
                                "status": "FAIL",
                                "cache_key": status.key
                            }).as_primitives())
                    else:
                        watch_queue.push(
                            WatchQueueMessage({
                                "status": "OK",
                                "cache_key": status.key
                            }).as_primitives())

        return queue_name

    def _get_watcher_list(self, sid):
        return ExpiringSet(make_watcher_list_name(sid), host=self.redis)
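
As a complement to the class above (and the request_work docstring), a hedged sketch of a worker pulling a task and of a watch queue; the worker id, service name/version, and sid values are placeholders.

# Hypothetical consumer of the DispatchClient above
client = DispatchClient()

# Mirror every result/error key for a submission onto a fresh queue
queue_name = client.setup_watch_queue('some-submission-sid')

# Pull one task for a service, blocking for up to 30 seconds
task = client.request_work(worker_id='worker-1',
                           service_name='Extract',
                           service_version='4.0.0',
                           timeout=30)
if task:
    # ... run the service, then report back through
    # client.service_finished(task.sid, result_key, result) or
    # client.service_failed(task.sid, error_key, error)
    pass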
Example #4
class RandomService(ServerBase):
    """Replaces everything past the dispatcher.

    Including service API, in the future probably include that in this test.
    """
    def __init__(self, datastore=None, filestore=None):
        super().__init__('assemblyline.randomservice')
        self.config = forge.get_config()
        self.datastore = datastore or forge.get_datastore()
        self.filestore = filestore or forge.get_filestore()
        self.client_id = get_random_id()
        self.service_state_hash = ExpiringHash(SERVICE_STATE_HASH, ttl=30 * 60)

        self.counters = {
            n: MetricsFactory('service', Metrics, name=n, config=self.config)
            for n in self.datastore.service_delta.keys()
        }
        self.queues = [
            forge.get_service_queue(name)
            for name in self.datastore.service_delta.keys()
        ]
        self.dispatch_client = DispatchClient(self.datastore)
        self.service_info = CachedObject(self.datastore.list_all_services,
                                         kwargs={'as_obj': False})

    def run(self):
        self.log.info("Random service result generator ready!")
        self.log.info("Monitoring queues:")
        for q in self.queues:
            self.log.info(f"\t{q.name}")

        self.log.info("Waiting for messages...")
        while self.running:
            # Reset Idle flags
            for s in self.service_info:
                if s['enabled']:
                    self.service_state_hash.set(
                        f"{self.client_id}_{s['name']}",
                        (s['name'], ServiceStatus.Idle, time.time() + 30 + 5))

            message = select(*self.queues, timeout=1)
            if not message:
                continue

            archive_ts = now_as_iso(
                self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            if self.config.submission.dtl:
                expiry_ts = now_as_iso(self.config.submission.dtl * 24 * 60 *
                                       60)
            else:
                expiry_ts = None
            queue, msg = message
            task = ServiceTask(msg)

            if not self.dispatch_client.running_tasks.add(
                    task.key(), task.as_primitives()):
                continue

            # Set service busy flag
            self.service_state_hash.set(
                f"{self.client_id}_{task.service_name}",
                (task.service_name, ServiceStatus.Running,
                 time.time() + 30 + 5))

            # METRICS
            self.counters[task.service_name].increment('execute')
            # METRICS (not caching here so always miss)
            self.counters[task.service_name].increment('cache_miss')

            self.log.info(
                f"\tQueue {queue} received a new task for sid {task.sid}.")
            action = random.randint(1, 10)
            if action >= 2:
                if action > 8:
                    result = random_minimal_obj(Result)
                else:
                    result = random_model_obj(Result)
                result.sha256 = task.fileinfo.sha256
                result.response.service_name = task.service_name
                result.archive_ts = archive_ts
                result.expiry_ts = expiry_ts
                result.response.extracted = result.response.extracted[task.depth + 2:]
                result.response.supplementary = result.response.supplementary[task.depth + 2:]
                result_key = Result.help_build_key(
                    sha256=task.fileinfo.sha256,
                    service_name=task.service_name,
                    service_version='0',
                    is_empty=result.is_empty())

                self.log.info(
                    f"\t\tA result was generated for this task: {result_key}")

                new_files = result.response.extracted + result.response.supplementary
                for f in new_files:
                    if not self.datastore.file.get(f.sha256):
                        random_file = random_model_obj(File)
                        random_file.archive_ts = archive_ts
                        random_file.expiry_ts = expiry_ts
                        random_file.sha256 = f.sha256
                        self.datastore.file.save(f.sha256, random_file)
                    if not self.filestore.exists(f.sha256):
                        self.filestore.put(f.sha256, f.sha256)

                time.sleep(random.randint(0, 2))

                self.dispatch_client.service_finished(task.sid, result_key,
                                                      result)

                # METRICS
                if result.result.score > 0:
                    self.counters[task.service_name].increment('scored')
                else:
                    self.counters[task.service_name].increment('not_scored')

            else:
                error = random_model_obj(Error)
                error.archive_ts = archive_ts
                error.expiry_ts = expiry_ts
                error.sha256 = task.fileinfo.sha256
                error.response.service_name = task.service_name
                error.type = random.choice(
                    ["EXCEPTION", "SERVICE DOWN", "SERVICE BUSY"])

                error_key = error.build_key('0')

                self.log.info(
                    f"\t\tA {error.response.status}:{error.type} "
                    f"error was generated for this task: {error_key}")

                self.dispatch_client.service_failed(task.sid, error_key, error)

                # METRICS
                if error.response.status == "FAIL_RECOVERABLE":
                    self.counters[task.service_name].increment(
                        'fail_recoverable')
                else:
                    self.counters[task.service_name].increment(
                        'fail_nonrecoverable')
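
Lastly, a hedged sketch of standing the random service up for a local test. The start()/stop() lifecycle is assumed to come from ServerBase and is not shown in the original, so the exact entry points may differ.

import time

# Hypothetical test harness for the RandomService above
if __name__ == '__main__':
    service = RandomService()
    try:
        service.start()  # assumed ServerBase entry point that runs run()
        time.sleep(60)   # let it consume queued tasks for a while
    finally:
        service.stop()   # assumed ServerBase shutdown hook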