Ejemplo n.º 1
0
class Plumber(CoreBase):
    """Core component that fails tasks queued for services that cannot run them.

    On every pass, the work queue of each known service that is missing from
    the datastore, disabled, or not in the ``Running`` stage is drained, and
    every task pulled from it is reported to the dispatcher as a
    non-recoverable failure.
    """

    # Literal prefix removed from redis keys to recover the service name.
    # Presumably this is what service_queue_name() prepends -- confirm against
    # its definition.
    SERVICE_QUEUE_PREFIX = 'service-queue-'

    def __init__(self, logger=None, shutdown_timeout: float = None, config=None,
                 redis=None, redis_persist=None, datastore=None, delay=60):
        """Set up the component and the dispatch client used to pull/fail tasks.

        Args:
            logger, shutdown_timeout, config, redis, redis_persist, datastore:
                Forwarded unchanged to the ``CoreBase`` constructor.
            delay: Number of seconds to sleep between two passes over the services.
        """
        super().__init__('plumber', logger, shutdown_timeout, config=config, redis=redis,
                         redis_persist=redis_persist, datastore=datastore)
        self.delay = float(delay)
        self.dispatch_client = DispatchClient(datastore=self.datastore, redis=self.redis,
                                              redis_persist=self.redis_persist, logger=self.log)

    @classmethod
    def _service_name_from_queue(cls, queue_name: str) -> str:
        """Return the service name encoded in a service queue key.

        Only the exact leading prefix is removed. ``str.lstrip`` must not be
        used here: it strips any leading characters found in its argument, so
        a name such as ``extract`` would be turned into ``xtract``.
        """
        prefix = cls.SERVICE_QUEUE_PREFIX
        if queue_name.startswith(prefix):
            return queue_name[len(prefix):]
        return queue_name

    def try_run(self):
        """Periodically drain the queues of unavailable services until stopped."""
        # Get an initial list of all the service queues
        service_queues = {
            self._service_name_from_queue(queue.decode('utf-8')): None
            for queue in self.redis.keys(service_queue_name('*'))
        }

        while self.running:
            self.heartbeat()
            # Reset the status of the service queues; names seen on previous
            # passes are kept so that queues of deleted services still get drained
            service_queues = {service_name: None for service_name in service_queues}

            # Update the service queue status based on current list of services
            for service in self.datastore.list_all_services(full=True):
                service_queues[service.name] = service

            for service_name, service in service_queues.items():
                if not service or not service.enabled or self.get_service_stage(service_name) != ServiceStage.Running:
                    # Pull tasks without blocking until the queue is empty
                    while True:
                        task = self.dispatch_client.request_work(None, service_name=service_name,
                                                                 service_version='0', blocking=False)
                        if task is None:
                            break

                        error = Error(dict(
                            archive_ts=now_as_iso(self.config.datastore.ilm.days_until_archive * 24 * 60 * 60),
                            created='NOW',
                            # task.ttl is expressed in days; no expiry when it is unset/zero
                            expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                            response=dict(
                                message='The service was disabled while processing this task.',
                                service_name=task.service_name,
                                service_version='0',
                                status='FAIL_NONRECOVERABLE',
                            ),
                            sha256=task.fileinfo.sha256,
                            type="TASK PRE-EMPTED",
                        ))

                        error_key = error.build_key(task=task)

                        self.dispatch_client.service_failed(task.sid, error_key, error)

            # Wait a while before checking status of all services again
            time.sleep(self.delay)
Ejemplo n.º 2
0
class TaskingClient:
    """A helper class to simplify tasking for privileged services and service-server.

    This tool helps take care of interactions between the filestore,
    datastore, dispatcher, and any sources of files to be processed.
    """
    def __init__(self,
                 datastore: AssemblylineDatastore = None,
                 filestore: FileStore = None,
                 config=None,
                 redis=None,
                 redis_persist=None,
                 identify=None):
        """Wire up the datastore, filestore, dispatcher and helper objects.

        Any of ``config``, ``datastore``, ``filestore`` and ``identify`` that is
        not supplied is obtained through ``forge``. An ``identify`` object
        created here is owned by this client and stopped by :meth:`stop`;
        one passed in by the caller is left alone.
        """
        self.log = logging.getLogger('assemblyline.tasking_client')

        # Fall back to forge-provided defaults for anything not injected
        self.config = config if config else forge.CachedObject(forge.get_config)
        self.datastore = datastore if datastore else forge.get_datastore(self.config)

        self.dispatch_client = DispatchClient(
            self.datastore, redis=redis, redis_persist=redis_persist)
        self.event_sender = EventSender('changes.services', redis)
        self.filestore = filestore if filestore else forge.get_filestore(self.config)
        self.heuristic_handler = HeuristicHandler(self.datastore)

        # Index the known heuristics by their ID
        known_heuristics = {}
        for heuristic in self.datastore.list_all_heuristics():
            known_heuristics[heuristic.heur_id] = heuristic
        self.heuristics = known_heuristics

        self.status_table = ExpiringHash(SERVICE_STATE_HASH, ttl=60 * 30, host=redis)

        # Note: the raw ``config`` argument (possibly None) is forwarded here
        safelister_kwargs = dict(log=self.log, config=config, datastore=self.datastore)
        self.tag_safelister = forge.CachedObject(
            forge.get_tag_safelister, kwargs=safelister_kwargs, refresh=300)

        # Only clean up the identify object when we are the ones creating it
        self.cleanup = not identify
        if self.cleanup:
            self.identify = forge.get_identify(
                config=self.config, datastore=self.datastore, use_cache=True)
        else:
            self.identify = identify

    def __enter__(self):
        """Enter the context manager; the client itself is the managed object."""
        client = self
        return client

    def __exit__(self, *_):
        """Leave the context manager, releasing resources through ``stop()``.

        Exception details are ignored and nothing is returned, so an exception
        raised inside the ``with`` body keeps propagating.
        """
        self.stop()

    def stop(self):
        """Stop the identify helper, but only if this client created it."""
        if not self.cleanup:
            return
        self.identify.stop()

    @elasticapm.capture_span(span_type='tasking_client')
    def upload_file(self,
                    file_path,
                    classification,
                    ttl,
                    is_section_image,
                    expected_sha256=None):
        """Identify a local file, record it in the datastore and store its content.

        Args:
            file_path: Path of the file to identify and upload.
            classification: Classification stored with the file information.
            ttl: Days until expiry; a falsy value means the file never expires.
            is_section_image: Forwarded to ``save_or_freshen_file``.
            expected_sha256: When given, the identified SHA256 must match it.

        Raises:
            TaskingClientException: If the identified SHA256 differs from
                ``expected_sha256``.
        """
        # Identify the file info of the uploaded file
        file_info = self.identify.fileinfo(file_path)
        sha256 = file_info['sha256']

        # Validate SHA256 of the uploaded file before touching any store
        hash_is_acceptable = expected_sha256 is None or expected_sha256 == sha256
        if not hash_is_acceptable:
            raise TaskingClientException(
                "Uploaded file does not match expected file hash. "
                f"[{file_info['sha256']} != {expected_sha256}]")

        seconds_until_archive = self.config.datastore.ilm.days_until_archive * 24 * 60 * 60
        file_info['archive_ts'] = now_as_iso(seconds_until_archive)
        file_info['classification'] = classification
        file_info['expiry_ts'] = now_as_iso(ttl * 24 * 60 * 60) if ttl else None

        # Update the datastore with the uploaded file
        self.datastore.save_or_freshen_file(
            sha256,
            file_info,
            file_info['expiry_ts'],
            file_info['classification'],
            is_section_image=is_section_image)

        # Upload file to the filestore (upload already checks if the file exists)
        self.filestore.upload(file_path, sha256)

    # Service
    @elasticapm.capture_span(span_type='tasking_client')
    def register_service(self, service_data, log_prefix=""):
        """Register a service, its version delta and its heuristics.

        ``service_data`` is modified in place: missing defaults are filled in
        and the ``heuristics``, ``file_required`` and ``tool_version`` keys are
        removed before the ``Service`` object is built.

        Args:
            service_data: Raw registration dictionary; must contain a
                ``docker_config`` mapping.
            log_prefix: Text prepended to every log message emitted here.

        Returns:
            A dict with:
              * ``keep_alive``: False only when this service version was newly
                saved by this call, True otherwise.
              * ``new_heuristics``: IDs of heuristics whose bulk upsert result
                was anything other than ``"noop"``.
              * ``service_config``: the service merged with its delta (as
                primitives), or an empty dict.

        Raises:
            ValueError: If a heuristic cannot be built or queued for upsert.
                The original exception is attached as ``__cause__``.
        """
        keep_alive = True

        # Get heuristics list
        heuristics = service_data.pop('heuristics', None)

        # Patch update_channel, registry_type before Service registration object creation
        service_data.setdefault('update_channel',
                                self.config.services.preferred_update_channel)
        service_data['docker_config'].setdefault(
            'registry_type', self.config.services.preferred_registry_type)
        service_data.setdefault('privileged',
                                self.config.services.prefer_service_privileged)
        for dep in service_data.get('dependencies', {}).values():
            container = dep['container']
            # Honour a registry type set on the container itself, consistently
            # with docker_config above. The previous code only looked at the
            # dependency level and therefore overwrote an explicit container
            # value; that lookup is kept as a fallback for compatibility.
            container['registry_type'] = container.get(
                'registry_type',
                dep.get('registry_type',
                        self.config.services.preferred_registry_type))

        # Pop unused registration service_data
        for unused_key in ('file_required', 'tool_version'):
            service_data.pop(unused_key, None)

        # Create Service registration object
        service = Service(service_data)

        # Fix service version, we don't need to see the stable label
        service.version = service.version.replace('stable', '')
        service_key = f'{service.name}_{service.version}'

        # Save service if it doesn't already exist
        if not self.datastore.service.exists(service_key):
            self.datastore.service.save(service_key, service)
            self.datastore.service.commit()
            self.log.info(f"{log_prefix}{service.name} registered")
            keep_alive = False

        # Save service delta if it doesn't already exist
        if not self.datastore.service_delta.exists(service.name):
            self.datastore.service_delta.save(service.name,
                                              {'version': service.version})
            self.datastore.service_delta.commit()
            self.log.info(f"{log_prefix}{service.name} "
                          f"version ({service.version}) registered")

        new_heuristics = []
        if heuristics:
            plan = self.datastore.heuristic.get_bulk_plan()
            for index, heuristic in enumerate(heuristics):
                # Until the heuristic is built, identify it by its position in the list for logging purposes
                heuristic_id = f'#{index}'
                try:
                    # Prefix the heuristic ID with the upper-cased service name
                    heuristic['heur_id'] = f"{service.name.upper()}.{str(heuristic['heur_id'])}"

                    # Attack_id field is now a list, make it a list if we receive otherwise
                    attack_id = heuristic.get('attack_id', None)
                    if isinstance(attack_id, str):
                        heuristic['attack_id'] = [attack_id]

                    heuristic = Heuristic(heuristic)
                    heuristic_id = heuristic.heur_id
                    plan.add_upsert_operation(heuristic_id, heuristic)
                except Exception as e:
                    msg = f"{service.name} has an invalid heuristic ({heuristic_id}): {str(e)}"
                    self.log.exception(f"{log_prefix}{msg}")
                    raise ValueError(msg) from e

            for item in self.datastore.heuristic.bulk(plan)['items']:
                update = item['update']
                if update['result'] != "noop":
                    new_heuristics.append(update['_id'])
                    self.log.info(
                        f"{log_prefix}{service.name} "
                        f"heuristic {update['_id']}: {update['result'].upper()}"
                    )

            self.datastore.heuristic.commit()

        service_config = self.datastore.get_service_with_delta(
            service.name, as_obj=False)

        # Notify components watching for service config changes
        self.event_sender.send(service.name, {
            'operation': Operation.Added,
            'name': service.name
        })

        return dict(keep_alive=keep_alive,
                    new_heuristics=new_heuristics,
                    service_config=service_config or dict())

    # Task
    @elasticapm.capture_span(span_type='tasking_client')
    def get_task(self,
                 client_id,
                 service_name,
                 service_version,
                 service_tool_version,
                 metric_factory,
                 status_expiry=None,
                 timeout=30):
        """Fetch the next task for a service client, answering from cache when possible.

        Returns:
            A ``(task, cache_found)`` tuple:
              * ``(None, False)`` when no task arrived before ``timeout``;
              * ``(task primitives, False)`` when the client must process the task;
              * ``(None, True)`` when a cached (possibly empty) result was found
                and already reported to the dispatcher as finished.

        Raises:
            ServiceMissingException: If ``service_name`` is unknown to the
                dispatch client.
        """
        if status_expiry is None:
            status_expiry = time.time() + timeout

        try:
            service_data = self.dispatch_client.service_data[service_name]
        except KeyError:
            raise ServiceMissingException(
                "The service you're asking task for does not exist, try later",
                404)

        # Set the service status to Idle since we will be waiting for a task
        self.status_table.set(
            client_id, (service_name, ServiceStatus.Idle, status_expiry))

        # Getting a new task
        task = self.dispatch_client.request_work(client_id,
                                                 service_name,
                                                 service_version,
                                                 timeout=timeout)
        if not task:
            # We've reached the timeout and no task found in service queue
            return None, False

        # We've got a task to process, consider us busy
        busy_until = time.time() + service_data.timeout
        self.status_table.set(
            client_id, (service_name, ServiceStatus.Running, busy_until))
        metric_factory.increment('execute')

        result_key = Result.help_build_key(
            sha256=task.fileinfo.sha256,
            service_name=service_name,
            service_version=service_version,
            service_tool_version=service_tool_version,
            is_empty=False,
            task=task)

        # Caching is off for this task or this service: hand the task out directly
        if task.ignore_cache or service_data.disable_cache:
            metric_factory.increment('cache_skipped')
            return task.as_primitives(), False

        # Checking for previous results for this key
        cached_result = self.datastore.result.get_if_exists(result_key)
        if cached_result:
            metric_factory.increment('cache_hit')
            score_metric = 'scored' if cached_result.result.score else 'not_scored'
            metric_factory.increment(score_metric)

            # Refresh the timestamps of the cached result for this submission
            cached_result.archive_ts = now_as_iso(
                self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            if task.ttl:
                cached_result.expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)

            self.dispatch_client.service_finished(task.sid, result_key,
                                                  cached_result)
            return None, True

        # Checking for previous empty results for this key
        empty_key = f"{result_key}.e"
        if self.datastore.emptyresult.get_if_exists(empty_key):
            metric_factory.increment('cache_hit')
            metric_factory.increment('not_scored')
            empty_result = self.datastore.create_empty_result_from_key(result_key)
            self.dispatch_client.service_finished(task.sid, empty_key,
                                                  empty_result)
            return None, True

        # No luck with the cache, lets dispatch the task to a client
        metric_factory.increment('cache_miss')
        return task.as_primitives(), False

    @elasticapm.capture_span(span_type='tasking_client')
    def task_finished(self, service_task, client_id, service_name,
                      metric_factory):
        """Record the outcome (result or error) a service client reports for a task.

        Returns:
            ``dict(success=True)`` when the result or error was handled,
            ``dict(success=False, missing_files=[...])`` when result files are
            missing, or ``None`` when ``service_task`` carries neither a
            ``result`` nor an ``error``.

        Any ``ValueError`` raised while handling the task propagates to the
        caller unchanged.
        """
        exec_time = service_task.get('exec_time')
        task = ServiceTask(service_task['task'])

        # Task created a result
        if 'result' in service_task:
            missing_files = self._handle_task_result(
                exec_time, task, service_task['result'], client_id,
                service_name, service_task['freshen'], metric_factory)
            if missing_files:
                return dict(success=False, missing_files=missing_files)
            return dict(success=True)

        # Task created an error
        if 'error' in service_task:
            self._handle_task_error(exec_time, task, service_task['error'],
                                    client_id, service_name, metric_factory)
            return dict(success=True)

        return None

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_result(self, exec_time: int, task: ServiceTask,
                            result: Dict[str, Any], client_id, service_name,
                            freshen: bool, metric_factory):
        """Post-process a service result and report it to the dispatcher.

        The ``result`` dictionary is modified in place (heuristics, tags, score
        and timestamps) before being turned into a ``Result`` object.

        Args:
            exec_time: Execution time reported by the service; only used in the
                completion log message.
            task: The task the result belongs to.
            result: Raw result dictionary produced by the service.
            client_id: Identifier of the reporting client, used for the status table.
            service_name: Name of the service that produced the result.
            freshen: When true, extracted/supplementary files are verified and
                freshened before anything else is done.
            metric_factory: Counter sink for the ``scored``/``not_scored`` metrics.

        Returns:
            The list of SHA256 hashes of missing files when ``freshen`` is set
            and at least one file is missing (nothing else is done in that
            case); otherwise ``None``.
        """
        def freshen_file(file_info_list, item):
            """Freshen one file entry; return True when the file is missing."""
            file_info = file_info_list.get(item['sha256'], None)
            # Missing means: no datastore record, or no content in the filestore
            if file_info is None or not self.filestore.exists(item['sha256']):
                return True
            else:
                # archive_ts / expiry_ts come from the enclosing scope, where they
                # are assigned before this function is ever called
                file_info['archive_ts'] = archive_ts
                file_info['expiry_ts'] = expiry_ts
                file_info['classification'] = item['classification']
                self.datastore.save_or_freshen_file(
                    item['sha256'],
                    file_info,
                    file_info['expiry_ts'],
                    file_info['classification'],
                    is_section_image=item.get('is_section_image', False))
            return False

        archive_ts = now_as_iso(self.config.datastore.ilm.days_until_archive *
                                24 * 60 * 60)
        # task.ttl is converted from days to seconds; a falsy ttl means no expiry
        if task.ttl:
            expiry_ts = now_as_iso(task.ttl * 24 * 60 * 60)
        else:
            expiry_ts = None

        # Check if all files are in the filestore
        if freshen:
            missing_files = []
            # De-duplicated hashes of every extracted and supplementary file
            hashes = list(
                set([
                    f['sha256'] for f in result['response']['extracted'] +
                    result['response']['supplementary']
                ]))
            file_infos = self.datastore.file.multiget(hashes,
                                                      as_obj=False,
                                                      error_on_missing=False)

            with elasticapm.capture_span(
                    name="handle_task_result.freshen_files",
                    span_type="tasking_client"):
                # Freshen the files concurrently; leaving the executor's with-block
                # waits for all submitted jobs to complete
                with concurrent.futures.ThreadPoolExecutor(
                        max_workers=5) as executor:
                    res = {
                        f['sha256']: executor.submit(freshen_file, file_infos,
                                                     f)
                        for f in result['response']['extracted'] +
                        result['response']['supplementary']
                    }
                for k, v in res.items():
                    if v.result():
                        missing_files.append(k)

            # Stop here so the caller can report the missing files
            if missing_files:
                return missing_files

        # Add scores to the heuristics, if any section set a heuristic
        with elasticapm.capture_span(
                name="handle_task_result.process_heuristics",
                span_type="tasking_client"):
            total_score = 0
            for section in result['result']['sections']:
                zeroize_on_sig_safe = section.pop('zeroize_on_sig_safe', True)
                section['tags'] = flatten(section['tags'])
                if section.get('heuristic'):
                    # Heuristic IDs are namespaced with the upper-cased service name
                    heur_id = f"{service_name.upper()}.{str(section['heuristic']['heur_id'])}"
                    section['heuristic']['heur_id'] = heur_id
                    try:
                        section[
                            'heuristic'], new_tags = self.heuristic_handler.service_heuristic_to_result_heuristic(
                                section['heuristic'], self.heuristics,
                                zeroize_on_sig_safe)
                        # Merge the tags returned with the heuristic, skipping duplicates
                        for tag in new_tags:
                            section['tags'].setdefault(tag[0], [])
                            if tag[1] not in section['tags'][tag[0]]:
                                section['tags'][tag[0]].append(tag[1])
                        total_score += section['heuristic']['score']
                    except InvalidHeuristicException:
                        # An invalid heuristic is dropped from the section
                        section['heuristic'] = None

        # Update the total score of the result
        result['result']['score'] = total_score

        # Add timestamps for creation, archive and expiry
        result['created'] = now_as_iso()
        result['archive_ts'] = archive_ts
        result['expiry_ts'] = expiry_ts

        # Pop the temporary submission data
        temp_submission_data = result.pop('temp_submission_data', None)
        if temp_submission_data:
            old_submission_data = {
                row.name: row.value
                for row in task.temporary_submission_data
            }
            # Keep only the keys that are new or whose value changed
            temp_submission_data = {
                k: v
                for k, v in temp_submission_data.items()
                if k not in old_submission_data or v != old_submission_data[k]
            }
            # Values whose string form exceeds the configured maximum length
            big_temp_data = {
                k: len(str(v))
                for k, v in temp_submission_data.items()
                if len(str(v)) > self.config.submission.max_temp_data_length
            }
            if big_temp_data:
                big_data_sizes = [f"{k}={v}" for k, v in big_temp_data.items()]
                self.log.warning(
                    f"[{task.sid}] The following temporary submission keys where ignored because they are "
                    "bigger then the maximum data size allowed "
                    f"[{self.config.submission.max_temp_data_length}]: {' | '.join(big_data_sizes)}"
                )
                # Drop the oversized entries
                temp_submission_data = {
                    k: v
                    for k, v in temp_submission_data.items()
                    if k not in big_temp_data
                }

        # Process the tag values
        with elasticapm.capture_span(name="handle_task_result.process_tags",
                                     span_type="tasking_client"):
            for section in result['result']['sections']:
                # Perform tag safelisting
                tags, safelisted_tags = self.tag_safelister.get_validated_tag_map(
                    section['tags'])
                section['tags'] = unflatten(tags)
                section['safelisted_tags'] = safelisted_tags

                # Build the tagging structure; entries that could not be used
                # are returned in `dropped` and logged below
                section['tags'], dropped = construct_safe(
                    Tagging, section.get('tags', {}))

                # Set section score to zero and lower total score if service is set to zeroize score
                # and all tags were safelisted
                if section.pop('zeroize_on_tag_safe', False) and \
                        section.get('heuristic') and \
                        len(tags) == 0 and \
                        len(safelisted_tags) != 0:
                    result['result']['score'] -= section['heuristic']['score']
                    section['heuristic']['score'] = 0

                if dropped:
                    self.log.warning(
                        f"[{task.sid}] Invalid tag data from {service_name}: {dropped}"
                    )

        # Build the Result object and hand it to the dispatcher
        result = Result(result)
        result_key = result.build_key(
            service_tool_version=result.response.service_tool_version,
            task=task)
        self.dispatch_client.service_finished(task.sid, result_key, result,
                                              temp_submission_data)

        # Metrics
        if result.result.score > 0:
            metric_factory.increment('scored')
        else:
            metric_factory.increment('not_scored')

        self.log.info(
            f"[{task.sid}] {client_id} - {service_name} "
            f"successfully completed task {f' in {exec_time}ms' if exec_time else ''}"
        )

        # Mark the client as idle again, with a 5 second expiry on the status
        self.status_table.set(
            client_id, (service_name, ServiceStatus.Idle, time.time() + 5))

    @elasticapm.capture_span(span_type='tasking_client')
    def _handle_task_error(self, exec_time: int, task: ServiceTask,
                           error: Dict[str, Any], client_id, service_name,
                           metric_factory) -> None:
        """Timestamp a service error, report it to the dispatcher and update metrics.

        The ``error`` dictionary is modified in place (``created``,
        ``archive_ts`` and, when the task has a ttl, ``expiry_ts``) before
        being turned into an ``Error`` object.
        """
        duration = f' in {exec_time}ms' if exec_time else ''
        self.log.info(
            f"[{task.sid}] {client_id} - {service_name} "
            f"failed to complete task {duration}"
        )

        # Add timestamps for creation, archive and expiry
        error['created'] = now_as_iso()
        seconds_until_archive = self.config.datastore.ilm.days_until_archive * 24 * 60 * 60
        error['archive_ts'] = now_as_iso(seconds_until_archive)
        if task.ttl:
            error['expiry_ts'] = now_as_iso(task.ttl * 24 * 60 * 60)

        error_obj = Error(error)
        error_key = error_obj.build_key(
            service_tool_version=error_obj.response.service_tool_version,
            task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error_obj)

        # Metrics
        recoverable = error_obj.response.status == 'FAIL_RECOVERABLE'
        metric_factory.increment(
            'fail_recoverable' if recoverable else 'fail_nonrecoverable')

        # Mark the client as idle again, with a 5 second expiry on the status
        idle_status = (service_name, ServiceStatus.Idle, time.time() + 5)
        self.status_table.set(client_id, idle_status)
def test_simple(clean_redis, clean_datastore):
    """Walk one single-file submission through every dispatch round.

    The test drives a ``Dispatcher`` by hand (no background threads): after
    each batch of service results or errors is reported through a
    ``DispatchClient``, the dispatcher is stepped explicitly and the state of
    the task and of the service queues is checked.
    """
    ds = clean_datastore
    redis = clean_redis

    def service_queue(name):
        # Shorthand to reach the work queue of a given service
        return get_service_queue(name, redis)

    # Store a random file, of type 'unknown', for the submission to reference
    file = random_model_obj(File)
    file_hash = file.sha256
    file.type = 'unknown'
    ds.file.save(file_hash, file)

    # Build a submission holding that single file
    sub: Submission = random_model_obj(models.submission.Submission)
    sub.sid = sid = 'first-submission'
    sub.params.ignore_cache = False
    sub.params.max_extracted = 5
    sub.params.classification = get_classification().UNRESTRICTED
    sub.params.initial_data = json.dumps({'cats': 'big'})
    sub.files = [dict(sha256=file_hash, name='file')]

    # The same redis connection is used for both redis arguments
    disp = Dispatcher(ds, redis, redis)
    disp.running = ToggleTrue()
    client = DispatchClient(ds, redis, redis)
    # Make the client aware of our dispatcher instance
    client.dispatcher_data_age = time.time()
    client.dispatcher_data.append(disp.instance_id)

    # Submit a problem, and check that it gets added to the dispatch hash
    # and the right service queues
    logger.info('==== first dispatch')
    client.dispatch_submission(sub)
    disp.pull_submissions()
    disp.service_worker(disp.process_queue_index(sid))
    task = disp.tasks.get(sid)

    assert task.queue_keys[(file_hash, 'extract')] is not None
    assert task.queue_keys[(file_hash, 'wrench')] is not None
    assert service_queue('extract').length() == 1
    assert service_queue('wrench').length() == 1

    # Making the same call again will queue it up again
    logger.info('==== second dispatch')
    disp.dispatch_file(task, file_hash)

    assert task.queue_keys[(file_hash, 'extract')] is not None
    assert task.queue_keys[(file_hash, 'wrench')] is not None
    assert service_queue('extract').length() == 1  # the queue doesn't pile up
    assert service_queue('wrench').length() == 1

    # Fail the extract task once; the asserts below expect it to be queued again
    logger.info('==== third dispatch')
    job = client.request_work('0', 'extract', '0')
    # The submission's initial_data is exposed to the service as temporary data
    assert job.temporary_submission_data == [{'name': 'cats', 'value': 'big'}]
    client.service_failed(sid, 'abc123', make_error(file_hash, 'extract'))
    # Deliberately do in the wrong order to make sure that works
    disp.pull_service_results()
    disp.service_worker(disp.process_queue_index(sid))

    assert task.queue_keys[(file_hash, 'extract')] is not None
    assert task.queue_keys[(file_hash, 'wrench')] is not None
    assert service_queue('extract').length() == 1

    # Mark extract as finished, wrench as failed
    logger.info('==== fourth dispatch')
    client.request_work('0', 'extract', '0')
    client.request_work('0', 'wrench', '0')
    client.service_finished(sid, 'extract-result',
                            make_result(file_hash, 'extract'))
    client.service_failed(sid, 'wrench-error',
                          make_error(file_hash, 'wrench', False))
    # One dispatcher step per reported service response
    for _ in range(2):
        disp.pull_service_results()
        disp.service_worker(disp.process_queue_index(sid))

    assert wait_error(task, file_hash, 'wrench')
    assert wait_result(task, file_hash, 'extract')
    # The next set of services should now have the file queued
    assert service_queue('av-a').length() == 1
    assert service_queue('av-b').length() == 1
    assert service_queue('frankenstrings').length() == 1

    # Have the AVs fail, frankenstrings finishes
    logger.info('==== fifth dispatch')
    client.request_work('0', 'av-a', '0')
    client.request_work('0', 'av-b', '0')
    client.request_work('0', 'frankenstrings', '0')
    client.service_failed(sid, 'av-a-error',
                          make_error(file_hash, 'av-a', False))
    client.service_failed(sid, 'av-b-error',
                          make_error(file_hash, 'av-b', False))
    client.service_finished(sid, 'f-result',
                            make_result(file_hash, 'frankenstrings'))
    # One dispatcher step per reported service response
    for _ in range(3):
        disp.pull_service_results()
        disp.service_worker(disp.process_queue_index(sid))

    assert wait_result(task, file_hash, 'frankenstrings')
    assert wait_error(task, file_hash, 'av-a')
    assert wait_error(task, file_hash, 'av-b')
    assert service_queue('xerox').length() == 1

    # Finish the xerox service and check if the submission completion got checked
    logger.info('==== sixth dispatch')
    client.request_work('0', 'xerox', '0')
    client.service_finished(sid, 'xerox-result-key',
                            make_result(file_hash, 'xerox'))
    disp.pull_service_results()
    disp.service_worker(disp.process_queue_index(sid))
    disp.save_submission()

    assert wait_result(task, file_hash, 'xerox')
    # Once the submission is complete the dispatcher no longer tracks the task
    assert disp.tasks.get(sid) is None
class Plumber(CoreBase):
    """Fail tasks that are stuck in, or overflowing, service queues.

    Each pass of ``try_run`` handles two situations:

    * queues of services that are unknown to the datastore, disabled, or not
      in the ``ServiceStage.Running`` stage are drained completely;
    * enabled services with a positive ``max_queue_length`` get a watcher
      thread (``watch_service``) that trims their queue down to that limit.

    Every task removed this way is reported to the dispatcher through
    ``DispatchClient.service_failed`` with a ``FAIL_NONRECOVERABLE`` /
    ``TASK PRE-EMPTED`` error.
    """

    # Prefix of the redis keys holding service queues.
    # NOTE(review): assumed to match what service_queue_name() generates.
    _QUEUE_PREFIX = 'service-queue-'

    def __init__(self,
                 logger=None,
                 shutdown_timeout: float = None,
                 config=None,
                 redis=None,
                 redis_persist=None,
                 datastore=None,
                 delay=60):
        """Set up the dispatch client and the watcher bookkeeping.

        All arguments except ``delay`` are forwarded to ``CoreBase.__init__``.

        Args:
            delay: Converted to ``float`` and handed to
                ``sleep_with_heartbeat`` between two passes over the services.
        """
        super().__init__('plumber',
                         logger,
                         shutdown_timeout,
                         config=config,
                         redis=redis,
                         redis_persist=redis_persist,
                         datastore=datastore)
        self.delay = float(delay)
        self.dispatch_client = DispatchClient(datastore=self.datastore,
                                              redis=self.redis,
                                              redis_persist=self.redis_persist,
                                              logger=self.log)

        # Per-service watcher state, all keyed by service name.
        self.flush_threads: dict[str, threading.Thread] = {}
        self.stop_signals: dict[str, threading.Event] = {}
        self.service_limit: dict[str, int] = {}

    def stop(self):
        """Signal every watcher thread to exit, then stop the component."""
        # Iterate over a snapshot: try_run adds/removes signals from its own
        # thread and a dict must not change size while being iterated.
        for sig in list(self.stop_signals.values()):
            sig.set()
        super().stop()

    @classmethod
    def _service_name_from_queue(cls, queue_name):
        """Return the service name encoded in a service queue key.

        Only the exact leading prefix is removed. ``str.lstrip`` must not be
        used here: it strips any leading run of the given *characters*, which
        eats the start of names such as ``extract`` or ``suricata``.
        """
        if queue_name.startswith(cls._QUEUE_PREFIX):
            return queue_name[len(cls._QUEUE_PREFIX):]
        return queue_name

    def try_run(self):
        """Main loop: purge unavailable services and manage queue watchers."""
        # Get an initial list of all the service queues
        service_queues: dict[str, Optional[Service]]
        service_queues = {
            self._service_name_from_queue(queue.decode('utf-8')): None
            for queue in self.redis.keys(service_queue_name('*'))
        }

        while self.running:
            # Reset the status of the service queues
            service_queues = dict.fromkeys(service_queues)

            # Update the service queue status based on current list of services
            for service in self.datastore.list_all_services(full=True):
                service_queues[service.name] = service

            for service_name, service in service_queues.items():
                current_stage = self.get_service_stage(service_name,
                                                       ServiceStage.Running)
                inactive = not service or not service.enabled

                # For disabled or otherwise unavailable services purge the queue
                if inactive or current_stage != ServiceStage.Running:
                    self._purge_queue(service_name)

                # For services that are enabled but limited
                if inactive or service.max_queue_length == 0:
                    self._stop_watcher(service_name)
                elif service.max_queue_length > 0:
                    self._ensure_watcher(service_name,
                                         service.max_queue_length)

            # Wait a while before checking status of all services again
            self.sleep_with_heartbeat(self.delay)

    def watch_service(self, service_name):
        """Thread body: keep a service queue at or below its configured limit.

        Runs until the component stops or the stop signal that was registered
        for ``service_name`` when this thread was started is set.
        """
        # Bind to the signal created for *this* thread. Looking it up again on
        # every iteration would either raise once the entry is removed or pick
        # up the signal of a replacement watcher and never terminate.
        stop_signal = self.stop_signals.get(service_name)
        if stop_signal is None:
            return

        service_queue = get_service_queue(service_name, self.redis)
        while self.running and not stop_signal.is_set():
            while True:
                # The limit disappears when the watcher is being stopped.
                limit = self.service_limit.get(service_name)
                if limit is None or service_queue.length() <= limit:
                    break

                task = self.dispatch_client.request_work(
                    'plumber',
                    service_name=service_name,
                    service_version='0',
                    blocking=False,
                    low_priority=True)
                if task is None:
                    break

                # Message text is kept verbatim from the original code.
                self._fail_task(task,
                                "Task canceled due to execesive queuing.")
            self.sleep(2)

    def _purge_queue(self, service_name):
        """Pull every queued task of a service and fail each one."""
        while True:
            task = self.dispatch_client.request_work(
                'plumber',
                service_name=service_name,
                service_version='0',
                blocking=False)
            if task is None:
                break

            self._fail_task(
                task, 'The service was disabled while processing this task.')
            # Draining a long queue can take a while; keep reporting liveness.
            self.heartbeat()

    def _stop_watcher(self, service_name):
        """Stop and forget the watcher of a service, if there is one."""
        # Remove the signal as well: leaving it registered made the next pass
        # re-enter this path and fail on the already removed limit/thread.
        stop_signal = self.stop_signals.pop(service_name, None)
        if stop_signal is not None:
            stop_signal.set()
        self.service_limit.pop(service_name, None)
        self.flush_threads.pop(service_name, None)

    def _ensure_watcher(self, service_name, limit):
        """Record the queue limit and start a watcher unless one is alive."""
        self.service_limit[service_name] = limit
        thread = self.flush_threads.get(service_name)
        if not thread or not thread.is_alive():
            # The signal must be registered before the thread starts, the
            # thread reads it as its first action.
            self.stop_signals[service_name] = threading.Event()
            thread = threading.Thread(target=self.watch_service,
                                      args=[service_name],
                                      daemon=True)
            self.flush_threads[service_name] = thread
            thread.start()

    def _fail_task(self, task, message):
        """Report ``task`` as pre-empted with a non-recoverable error."""
        archive_delay = self.config.datastore.ilm.days_until_archive * 24 * 60 * 60
        error = Error(
            dict(
                archive_ts=now_as_iso(archive_delay),
                created='NOW',
                # Tasks without a ttl get no expiry timestamp.
                expiry_ts=now_as_iso(task.ttl * 24 * 60 * 60) if task.ttl else None,
                response=dict(
                    message=message,
                    service_name=task.service_name,
                    service_version='0',
                    status='FAIL_NONRECOVERABLE',
                ),
                sha256=task.fileinfo.sha256,
                type="TASK PRE-EMPTED",
            ))

        error_key = error.build_key(task=task)
        self.dispatch_client.service_failed(task.sid, error_key, error)
Ejemplo n.º 5
0
class MockService(ServerBase):
    """Stand-in for everything that sits behind the dispatcher.

    This covers the service API too; a later version of this test will
    probably exercise that part for real.

    The content of each submitted file is JSON mapping service names to the
    instructions this mock should follow for that file.
    """
    def __init__(self, name, datastore, redis, filestore):
        super().__init__('assemblyline.service.' + name)
        self.service_name = name
        self.datastore = datastore
        self.filestore = filestore
        self.queue = get_service_queue(name, redis)
        self.dispatch_client = DispatchClient(self.datastore, redis)
        # Both counters are keyed by file sha256.
        self.hits = {}
        self.drops = {}

    def try_run(self):
        """Fetch tasks for this service and act on the file's instructions."""
        while self.running:
            task = self.dispatch_client.request_work(
                'worker', self.service_name, '0', timeout=1)
            if not task:
                continue
            print(self.service_name, 'has received a job', task.sid)

            sha256 = task.fileinfo.sha256

            # The file body carries the per-service script for this test.
            script = json.loads(self.filestore.get(sha256))
            instructions = script.get(self.service_name, {})
            print(self.service_name, 'following instruction:', instructions)

            hit_count = self.hits.get(sha256, 0) + 1
            self.hits[sha256] = hit_count

            # Wait on the shared semaphore, then give up on the task.
            semaphore_timeout = instructions.get('semaphore', False)
            if semaphore_timeout:
                _global_semaphore.acquire(blocking=True,
                                          timeout=semaphore_timeout)
                continue

            # Silently drop the task for the first `drop` hits on this file.
            if 'drop' in instructions and instructions['drop'] >= hit_count:
                self.drops[sha256] = self.drops.get(sha256, 0) + 1
                continue

            # Report the scripted error instead of a result.
            if instructions.get('failure', False):
                error = Error(instructions['error'])
                error.sha256 = sha256
                self.dispatch_client.service_failed(task.sid,
                                                    error=error,
                                                    error_key=get_random_id())
                continue

            # Default result, overridable by the instructions below.
            result_data = {
                'archive_ts': time.time() + 300,
                'classification': 'U',
                'response': {
                    'service_version': '0',
                    'service_tool_version': '0',
                    'service_name': self.service_name,
                },
                'result': {},
                'sha256': sha256,
                'expiry_ts': time.time() + 600
            }

            overrides = instructions.get('result', {})
            result_data.update(overrides)
            response_overrides = instructions.get('response', {})
            result_data['response'].update(response_overrides)

            result = Result(result_data)
            result_key = instructions.get('result_key', get_random_id())
            self.dispatch_client.service_finished(task.sid, result_key, result)
Ejemplo n.º 6
0
class RandomService(ServerBase):
    """Stand-in for everything that sits behind the dispatcher.

    This covers the service API too; a later version of this test will
    probably exercise that part for real.

    Tasks are pulled from every service queue and answered with randomly
    generated results or errors.
    """
    def __init__(self, datastore=None, filestore=None):
        super().__init__('assemblyline.randomservice')
        self.config = forge.get_config()
        self.datastore = datastore if datastore else forge.get_datastore()
        self.filestore = filestore if filestore else forge.get_filestore()
        self.client_id = get_random_id()
        self.service_state_hash = ExpiringHash(SERVICE_STATE_HASH, ttl=30 * 60)

        # One metrics counter and one queue per service known to the datastore.
        self.counters = {
            service_name: MetricsFactory('service',
                                         Metrics,
                                         name=service_name,
                                         config=self.config)
            for service_name in self.datastore.service_delta.keys()
        }
        self.queues = [
            forge.get_service_queue(service_name)
            for service_name in self.datastore.service_delta.keys()
        ]
        self.dispatch_client = DispatchClient(self.datastore)
        self.service_info = CachedObject(self.datastore.list_all_services,
                                         kwargs={'as_obj': False})

    def _set_service_status(self, service_name, status):
        """Publish ``status`` for this client/service pair in the state hash."""
        state_key = f"{self.client_id}_{service_name}"
        self.service_state_hash.set(
            state_key, (service_name, status, time.time() + 30 + 5))

    def run(self):
        """Serve tasks from all queues with random results until stopped."""
        self.log.info("Random service result generator ready!")
        self.log.info("Monitoring queues:")
        for monitored in self.queues:
            self.log.info(f"\t{monitored.name}")

        self.log.info("Waiting for messages...")
        while self.running:
            # Reset Idle flags
            for svc in self.service_info:
                if svc['enabled']:
                    self._set_service_status(svc['name'], ServiceStatus.Idle)

            selected = select(*self.queues, timeout=1)
            if not selected:
                continue

            archive_ts = now_as_iso(
                self.config.datastore.ilm.days_until_archive * 24 * 60 * 60)
            dtl = self.config.submission.dtl
            expiry_ts = now_as_iso(dtl * 24 * 60 * 60) if dtl else None

            queue, msg = selected
            task = ServiceTask(msg)

            # Skip the task when it could not be added to the running tasks.
            if not self.dispatch_client.running_tasks.add(
                    task.key(), task.as_primitives()):
                continue

            # Set service busy flag
            self._set_service_status(task.service_name, ServiceStatus.Running)

            # METRICS (nothing is cached here, so every task is a miss)
            self.counters[task.service_name].increment('execute')
            self.counters[task.service_name].increment('cache_miss')

            self.log.info(
                f"\tQueue {queue} received a new task for sid {task.sid}.")

            # 1 -> error, 2..8 -> full random result, 9..10 -> minimal result
            action = random.randint(1, 10)
            if action < 2:
                error = random_model_obj(Error)
                error.archive_ts = archive_ts
                error.expiry_ts = expiry_ts
                error.sha256 = task.fileinfo.sha256
                error.response.service_name = task.service_name
                error.type = random.choice(
                    ["EXCEPTION", "SERVICE DOWN", "SERVICE BUSY"])

                error_key = error.build_key('0')

                self.log.info(
                    f"\t\tA {error.response.status}:{error.type} "
                    f"error was generated for this task: {error_key}")

                self.dispatch_client.service_failed(task.sid, error_key, error)

                # METRICS
                recoverable = error.response.status == "FAIL_RECOVERABLE"
                self.counters[task.service_name].increment(
                    'fail_recoverable' if recoverable else 'fail_nonrecoverable')

            else:
                build = random_minimal_obj if action > 8 else random_model_obj
                result = build(Result)
                result.sha256 = task.fileinfo.sha256
                result.response.service_name = task.service_name
                result.archive_ts = archive_ts
                result.expiry_ts = expiry_ts

                # Drop leading entries from both file lists based on task depth.
                keep_from = task.depth + 2
                result.response.extracted = result.response.extracted[keep_from:]
                result.response.supplementary = result.response.supplementary[
                    task.depth + 2:]

                result_key = Result.help_build_key(
                    sha256=task.fileinfo.sha256,
                    service_name=task.service_name,
                    service_version='0',
                    is_empty=result.is_empty())

                self.log.info(
                    f"\t\tA result was generated for this task: {result_key}")

                # Make sure every referenced file exists in both stores.
                new_files = result.response.extracted + result.response.supplementary
                for new_file in new_files:
                    if not self.datastore.file.get(new_file.sha256):
                        file_record = random_model_obj(File)
                        file_record.archive_ts = archive_ts
                        file_record.expiry_ts = expiry_ts
                        file_record.sha256 = new_file.sha256
                        self.datastore.file.save(new_file.sha256, file_record)
                    if not self.filestore.exists(new_file.sha256):
                        self.filestore.put(new_file.sha256, new_file.sha256)

                # Pause for a random 0-2 seconds before reporting the result.
                time.sleep(random.randint(0, 2))

                self.dispatch_client.service_finished(task.sid, result_key,
                                                      result)

                # METRICS
                scored = result.result.score > 0
                self.counters[task.service_name].increment(
                    'scored' if scored else 'not_scored')