예제 #1
0
    def save(self, *args, **kwargs):
        """Stamp lifecycle timestamps from ``status`` and then delegate to the parent ``save``.

        When the instance has a truthy ``status``:
          * ``started_at`` is set to the current time if the status is ``RUNNING``;
          * ``finished_at`` is set to the current time if the status is one of
            ``TaskState.get_finished_states()``.

        All positional and keyword arguments are forwarded unchanged to the parent ``save``.
        """
        # Function-scope import, as in the original (presumably to avoid a circular import -- confirm).
        from eventkit_cloud.tasks.enumerations import TaskState

        # Instances without a ``status`` attribute (or with an empty one) are saved untouched.
        current_status = getattr(self, "status", None)
        if current_status:
            state = TaskState[current_status]
            if state == TaskState.RUNNING:
                self.started_at = timezone.now()
            if state in TaskState.get_finished_states():
                self.finished_at = timezone.now()

        super(TimeTrackingModelMixin, self).save(*args, **kwargs)
예제 #2
0
    def after_return(self, status, retval, task_id, args, kwargs, einfo):
        """Queue a worker shutdown once this task's queue has no remaining work.

        Called after the task returns. The parent handler always runs first. The
        shutdown logic only runs when the ``PCF_SCALING`` environment variable is
        set and ``settings.CELERY_SCALE_BY_RUN`` is falsy.

        :param status: Task state, forwarded to the parent handler.
        :param retval: Task return value or exception, forwarded to the parent handler.
        :param task_id: Id of the task, forwarded to the parent handler.
        :param args: Original positional task arguments, forwarded to the parent handler.
        :param kwargs: Original keyword task arguments, forwarded to the parent handler.
        :param einfo: Exception info, forwarded to the parent handler.
        :return: ``{"action": "skip_shutdown", "workers": [...]}`` for ``run``/``scale`` workers,
            ``{"action": "shutdown", "workers": [...]}`` when a shutdown task was queued,
            otherwise ``None``.
        """
        super(EventKitBaseTask, self).after_return(status, retval, task_id, args, kwargs, einfo)

        # NOTE(review): any non-empty value enables this path, including the string "False".
        pcf_scaling = os.getenv("PCF_SCALING", False)
        if not pcf_scaling:
            return None

        from eventkit_cloud.tasks.util_tasks import shutdown_celery_workers

        queue_type, hostname = self.request.hostname.split("@")

        # In our current setup the queue name always mirrors the routing_key, if this changes this logic will break.
        queue_name = self.request.delivery_info["routing_key"]
        if getattr(settings, "CELERY_SCALE_BY_RUN"):
            return None

        logger.info(f"{self.name} has completed, sending shutdown_celery_workers task to queue {queue_name}.")
        app_name = os.getenv("CELERY_TASK_APP") or json.loads(os.getenv("VCAP_APPLICATION", "{}")).get(
            "application_name"
        )

        if os.getenv("PCF_SCALING"):
            client = Pcf()
            client.login()
        else:
            # NOTE(review): unreachable while the guard above also requires PCF_SCALING to be set.
            client = Docker()
            app_name = settings.DOCKER_IMAGE_NAME

        workers = [f"{queue_type}@{hostname}", f"priority@{hostname}"]
        if queue_type in ["run", "scale"]:
            return {"action": "skip_shutdown", "workers": workers}

        messages = get_message_count(queue_name)
        running_tasks_by_queue = client.get_running_tasks(app_name, queue_name)
        logger.info(f"RUNNING TASKS BY QUEUE: {running_tasks_by_queue}")
        running_tasks_by_queue_count = running_tasks_by_queue["pagination"]["total_results"]

        # Never shut down while this worker still owns unfinished export tasks.
        # ``exists()`` avoids loading the matching rows just to test for emptiness.
        unfinished_states = [task_state.value for task_state in TaskState.get_not_finished_states()]
        if ExportTaskRecord.objects.filter(worker=hostname, status__in=unfinished_states).exists():
            return None

        # Shut down when there are more running tasks than queued messages, or when the queue is
        # completely idle. The idle check must use the *count*: the original compared the response
        # dict itself to 0, which can never be true.
        if running_tasks_by_queue_count > messages or (running_tasks_by_queue_count == 0 and messages == 0):
            shutdown_celery_workers.s().apply_async(queue=queue_name, routing_key=queue_name)
            # return value is unused but useful for storing in the celery result.
            return {"action": "shutdown", "workers": workers}
        return None
예제 #3
0
def rerun_data_provider_records(run_uid, user_id, data_provider_slugs):
    """Clone a run and reset the selected provider records so they are processed again.

    Walks up the ``parent_run`` chain while the run is still cloning, collecting the
    slugs of providers that are still pending. It then creates a new run cloned from
    the resolved run, marks every matching (or unfinished) provider record of the new
    run as PENDING, deletes their tasks, and marks the new run as SUBMITTED.
    Everything happens inside a single database transaction.

    :param run_uid: uid of the ExportRun to rerun.
    :param user_id: Primary key of the User requesting the rerun.
    :param data_provider_slugs: Iterable of provider slugs to rerun. It is not modified.
    :return: ``None`` on success; a 400 ``Response`` when ``create_run`` raises
        ``InvalidLicense`` or ``Error``.
    :raises PermissionDenied: when ``create_run`` raises ``Unauthorized``.
    """
    from eventkit_cloud.tasks.task_factory import create_run, Error, Unauthorized, InvalidLicense

    # Collect into a private set: the caller's list is left untouched, duplicates are
    # dropped for free and the membership checks below are O(1).
    slugs_to_rerun = set(data_provider_slugs or ())

    with transaction.atomic():
        old_run: ExportRun = ExportRun.objects.select_related("job__user", "parent_run__job__user").get(uid=run_uid)

        user: User = User.objects.get(pk=user_id)

        while old_run and old_run.is_cloning:
            # Find pending providers and add them to the set.
            for dptr in old_run.data_provider_task_records.all():
                # Records without a provider are skipped, matching the
                # ``provider is not None`` guard used when resetting records below.
                if dptr.provider is not None and dptr.status == TaskState.PENDING.value:
                    slugs_to_rerun.add(dptr.provider.slug)
            old_run = old_run.parent_run

        try:
            new_run_uid = create_run(job=old_run.job, user=user, clone=old_run, download_data=False)
        except Unauthorized:
            raise PermissionDenied(
                code="permission_denied", detail="ADMIN permission is required to run this DataPack."
            )
        except (InvalidLicense, Error) as err:
            return Response([{"detail": _(str(err))}], status.HTTP_400_BAD_REQUEST)

        run: ExportRun = ExportRun.objects.get(uid=new_run_uid)

        # Reset the old data provider task record for the providers we're recreating.
        data_provider_task_record: DataProviderTaskRecord
        run.data_provider_task_records.filter(slug="run").delete()
        for data_provider_task_record in run.data_provider_task_records.all():
            if data_provider_task_record.provider is None:
                continue
            # Have to clean out the tasks that were finished and request the ones that weren't.
            if (
                data_provider_task_record.provider.slug in slugs_to_rerun
                or TaskState[data_provider_task_record.status] in TaskState.get_not_finished_states()
            ):
                data_provider_task_record.status = TaskState.PENDING.value
                # Delete the associated tasks so that they can be recreated.
                data_provider_task_record.tasks.all().delete()
                data_provider_task_record.save()

        run.status = TaskState.SUBMITTED.value
        run.save()
예제 #4
0
    def after_return(self, status, retval, task_id, args, kwargs, einfo):
        """Kill this task's worker once its queue has no remaining work.

        Called after the task returns. The parent handler always runs first. The
        shutdown logic only runs when ``settings.PCF_SCALING`` is truthy and
        ``settings.CELERY_SCALE_BY_RUN`` is falsy.

        :param status: Task state, forwarded to the parent handler.
        :param retval: Task return value or exception, forwarded to the parent handler.
        :param task_id: Id of the task, forwarded to the parent handler.
        :param args: Original positional task arguments, forwarded to the parent handler.
        :param kwargs: Original keyword task arguments, forwarded to the parent handler.
        :param einfo: Exception info, forwarded to the parent handler.
        :return: ``{"action": "skip_shutdown", "workers": [...]}`` for ``run``/``scale`` workers,
            ``{"action": "shutdown", "workers": [...]}`` when the worker was killed,
            otherwise ``None``.
        """
        super(EventKitBaseTask, self).after_return(status, retval, task_id, args, kwargs, einfo)

        pcf_scaling = settings.PCF_SCALING  # type: ignore  # issue with django-stubs
        if not pcf_scaling:
            return None

        from eventkit_cloud.tasks.scheduled_tasks import kill_worker  # type: ignore

        queue_type, hostname = self.request.hostname.split("@")

        # In our current setup the queue name always mirrors the routing_key, if this changes this logic will break.
        queue_name = self.request.delivery_info["routing_key"]
        if getattr(settings, "CELERY_SCALE_BY_RUN"):
            return None

        logger.info(f"{self.name} has completed, shutting down queue {queue_name}.")
        client, app_name = get_scale_client()

        workers = [f"{queue_type}@{hostname}", f"priority@{hostname}"]
        if queue_type in ["run", "scale"]:
            return {"action": "skip_shutdown", "workers": workers}

        messages = get_message_count(queue_name)
        running_tasks_by_queue = client.get_running_tasks(app_name, queue_name)
        logger.info(f"RUNNING TASKS BY QUEUE: {running_tasks_by_queue}")
        running_tasks_by_queue_count = running_tasks_by_queue["pagination"]["total_results"]

        # Never shut down while this worker still owns unfinished export tasks.
        # ``exists()`` avoids loading the matching rows just to test for emptiness.
        unfinished_states = [task_state.value for task_state in TaskState.get_not_finished_states()]
        if ExportTaskRecord.objects.filter(worker=hostname, status__in=unfinished_states).exists():
            return None

        # Shut down when there are more running tasks than queued messages, or when the queue is
        # completely idle. The idle check must use the *count*: the original compared the response
        # dict itself to 0, which can never be true.
        if running_tasks_by_queue_count > messages or (running_tasks_by_queue_count == 0 and messages == 0):
            kill_worker(task_name=queue_name, client=client)
            # return value is unused but useful for storing in the celery result.
            return {"action": "shutdown", "workers": workers}
        return None
예제 #5
0
    def parse_tasks(
        self,
        worker=None,
        run_uid=None,
        user_details=None,
        run_zip_file_slug_sets=None,
        session_token=None,
        queue_group=None,
    ):
        """
        Build and dispatch the celery task chains for every unfinished data provider of a run.

        This handles all of the logic for taking the individual celery tasks and grouping them under
        specific providers.

        Each Provider (e.g. OSM) gets a chain:  OSM_TASK -> FORMAT_TASKS = PROVIDER_SUBTASK_CHAIN
        They need to be finalized (was the task successful?) to update the database state:
            PROVIDER_SUBTASK_CHAIN -> FINALIZE_PROVIDER_TASK

        We also have an optional chain of tasks that get processed after the providers are run:
            AD_HOC_TASK1 -> AD_HOC_TASK2 -> FINALIZE_RUN_TASK = FINALIZE_RUN_TASK_COLLECTION

        If the PROVIDER_SUBTASK_CHAIN fails it needs to be cleaned up.  The clean up task also calls the
        finalize provider task. This is because when a task fails the failed task will call an on_error (link_error)
        task and never return.
            PROVIDER_SUBTASK_CHAIN -> FINALIZE_PROVIDER_TASK
                   |
                   v
                CLEAN_UP_FAILURE_TASK -> FINALIZE_PROVIDER_TASK

        Now there needs to be some way for the finalize tasks to be called.  Since we now have several possible
        forked paths, we need each path to check the state of the providers to see if they are all finished before
        moving on.
        It would be great if celery implicitly handled that, but it doesn't ever merge the forked paths.
        So we add a WAIT_FOR_PROVIDERS task to check state; once the providers are ready they call the final tasks.

        PROVIDER_SUBTASK_CHAIN -> FINALIZE_PROVIDER_TASK -> WAIT_FOR_PROVIDERS   \
                   |                                                              ==> FINALIZE_RUN_TASK_COLLECTION
                   v                                                             /
            CLEAN_UP_FAILURE_TASK -> FINALIZE_PROVIDER_TASK -> WAIT_FOR_PROVIDERS


        :param worker: A worker node (hostname) for a celery worker; this should match the node name used when
         starting the celery worker.
        :param run_uid: A uid to reference an ExportRun. Required; a falsy value raises an Exception.
        :param user_details: Details about the requesting user, forwarded to the created tasks. When None, a
         placeholder dict with a masked username is used.
        :param run_zip_file_slug_sets: Forwarded unchanged to create_finalize_run_task_collection.
        :param session_token: Forwarded unchanged to TaskChainBuilder.build_tasks.
        :param queue_group: Prefix used to build the "<queue_group>.priority" queue and routing key for the
         finalize and wait-for-providers tasks.
        :return: None. Each provider chain is dispatched with apply_async; no result is returned.
        """
        # This is just to make it easier to trace when user_details haven't been sent
        if user_details is None:
            user_details = {"username": "******"}

        if not run_uid:
            raise Exception("Cannot parse_tasks without a run uid.")

        run = ExportRun.objects.prefetch_related(
            "job__projections", "job__data_provider_tasks", "data_provider_task_records"
        ).get(uid=run_uid)
        job = run.job
        run_dir = get_run_staging_dir(run.uid)

        # NOTE(review): if queue_group is left as None these names become "None.priority" -- confirm that
        # callers always pass a queue_group.
        wait_for_providers_settings = {
            "queue": f"{queue_group}.priority",
            "routing_key": f"{queue_group}.priority",
            "priority": TaskPriority.FINALIZE_PROVIDER.value,
        }

        finalize_task_settings = {
            "interval": 4,
            "max_retries": 10,
            "queue": f"{queue_group}.priority",
            "routing_key": f"{queue_group}.priority",
            "priority": TaskPriority.FINALIZE_RUN.value,
        }

        # Chains are only collected here; they are all dispatched together at the end of the method.
        finalized_provider_task_chain_list = []
        # Create a task record which can hold tasks for the run (datapack)
        run_task_record, created = DataProviderTaskRecord.objects.get_or_create(
            run=run, name="run", slug="run", defaults={"status": TaskState.PENDING.value, "display": False}
        )
        if created:
            logger.info("New data provider task record created")
            # The status was already set through ``defaults`` above; this re-applies the same value.
            run_task_record.status = TaskState.PENDING.value
            run_task_record.save()

        # Zip chain attached to the run-level record; it is handed to the finalize-run collection below.
        run_zip_task_chain = get_zip_task_chain(
            data_provider_task_record_uid=run_task_record.uid,
            worker=worker,
        )
        for data_provider_task in job.data_provider_tasks.all():

            # Skip providers whose record for this run has already reached a finished state.
            data_provider_task_record = run.data_provider_task_records.filter(
                provider__slug=data_provider_task.provider.slug
            ).first()
            if (
                data_provider_task_record
                and TaskState[data_provider_task_record.status] in TaskState.get_finished_states()
            ):
                continue

            # Provider types without an entry in type_task_map are silently ignored.
            if self.type_task_map.get(data_provider_task.provider.export_provider_type.type_name):
                # Each task builder has a primary task which pulls the source data, grab that task here...
                type_name = data_provider_task.provider.export_provider_type.type_name

                primary_export_task = self.type_task_map.get(type_name)

                stage_dir = get_provider_staging_dir(run_dir, data_provider_task.provider.slug)
                args = {
                    "primary_export_task": primary_export_task,
                    "user": job.user,
                    "provider_task_uid": data_provider_task.uid,
                    "stage_dir": stage_dir,
                    "run": run,
                    "service_type": data_provider_task.provider.export_provider_type.type_name,
                    "worker": worker,
                    "user_details": user_details,
                    "session_token": session_token,
                }

                (
                    provider_task_record_uid,
                    provider_subtask_chain,
                ) = TaskChainBuilder().build_tasks(**args)

                # Every provider chain ends with its own wait-for-providers signature, which carries the
                # finalize-run collection as its callback.
                wait_for_providers_signature = wait_for_providers_task.s(
                    run_uid=run_uid,
                    locking_task_key=run_uid,
                    callback_task=create_finalize_run_task_collection(
                        run_uid=run_uid,
                        run_provider_task_record_uid=run_task_record.uid,
                        run_zip_task_chain=run_zip_task_chain,
                        run_zip_file_slug_sets=run_zip_file_slug_sets,
                        apply_args=finalize_task_settings,
                    ),
                    apply_args=finalize_task_settings,
                ).set(**wait_for_providers_settings)

                # A falsy subtask chain means nothing is queued for this provider.
                if provider_subtask_chain:
                    # The finalize_export_provider_task will check all of the export tasks
                    # for this provider and save the export provider's status.

                    selection_task = create_task(
                        data_provider_task_record_uid=provider_task_record_uid,
                        worker=worker,
                        stage_dir=stage_dir,
                        task=output_selection_geojson_task,
                        selection=job.the_geom.geojson,
                        user_details=user_details,
                    )

                    # create signature to close out the provider tasks
                    finalize_export_provider_signature = finalize_export_provider_task.s(
                        data_provider_task_uid=provider_task_record_uid,
                        status=TaskState.COMPLETED.value,
                        locking_task_key=run_uid,
                    ).set(**finalize_task_settings)

                    # add zip if required
                    # skip zip if there is only one source in the data pack (they would be redundant files).
                    if data_provider_task.provider.zip and len(job.data_provider_tasks.all()) > 1:
                        zip_export_provider_sig = get_zip_task_chain(
                            data_provider_task_record_uid=provider_task_record_uid,
                            data_provider_task_record_uids=[provider_task_record_uid],
                            worker=worker,
                        )
                        provider_subtask_chain = chain(provider_subtask_chain, zip_export_provider_sig)

                    # Order within the chain: selection -> provider subtasks (+ zip) -> finalize -> wait.
                    finalized_provider_task_chain_list.append(
                        chain(
                            selection_task,
                            provider_subtask_chain,
                            finalize_export_provider_signature,
                            wait_for_providers_signature,
                        )
                    )

        # we kick off all of the sub-tasks at once down here rather than one at a time in the for loop above so
        # that if an error occurs earlier on in the method, all of the tasks will fail rather than an undefined
        # number of them. this simplifies error handling, because we don't have to deduce which tasks were
        # successfully kicked off and which ones failed.
        for item in finalized_provider_task_chain_list:
            item.apply_async(**finalize_task_settings)
예제 #6
0
def scale_by_runs(max_tasks_memory):
    """Stop workers for finished runs and start a worker for each submitted run that fits.

    First kills the workers whose task name matches a run that is finished or deleted.
    Then, for every submitted run that does not already have a consumer, queues
    ``pick_up_run_task`` on a queue named after the run uid and starts a worker for it.
    Stops as soon as the concurrency limit (``RUNS_CONCURRENCY`` environment variable,
    default 3; 0 disables the limit) or the memory budget would be exceeded.

    @param max_tasks_memory: The amount of memory in MB to allow for all of the tasks.
    @type max_tasks_memory: int
    """
    from audit_logging.utils import get_user_details

    client, app_name = get_scale_client()

    celery_task_details = get_celery_task_details(client, app_name)
    running_tasks_memory = int(celery_task_details["memory"])
    celery_tasks = get_celery_tasks_scale_by_run()

    # Check if we need to scale for default system tasks.
    scale_default_tasks(client, app_name, celery_tasks)

    # Get run in progress
    runs = ExportRun.objects.filter(status=TaskState.SUBMITTED.value, deleted=False)
    total_tasks = 0
    running_tasks = client.get_running_tasks(app_name)
    logger.info(f"Running tasks: {running_tasks}")

    if running_tasks:
        total_tasks = running_tasks["pagination"].get("total_results", 0)
        # Get a list of running task names excluding the default celery tasks.
        running_task_names = [
            resource.get("name") for resource in running_tasks.get("resources") if resource.get("name") != "celery"
        ]
        finished_states = [state.value for state in TaskState.get_finished_states()]
        finished_runs = ExportRun.objects.filter(
            Q(uid__in=running_task_names) & (Q(status__in=finished_states) | Q(deleted=True))
        )

        finished_run_uids = []
        for finished_run in finished_runs:
            logger.info(
                f"Stopping {finished_run.uid} because it is in a finished state ({finished_run.status}) "
                f"or was deleted ({finished_run.deleted})."
            )
            finished_run_uids.append(str(finished_run.uid))
        kill_workers(task_names=finished_run_uids, client=client)

    # The limit does not change while scaling, so read and parse it once instead of per run.
    max_runs = int(os.getenv("RUNS_CONCURRENCY", 3))

    for run in runs:
        # Copy the template: its "command" is formatted per run further down.
        celery_run_task = copy.deepcopy(celery_tasks["run"])

        logger.info(f"Checking to see if submitted run {run.uid} needs a new worker.")

        if max_runs and total_tasks >= max_runs:
            logger.info(f"total_tasks ({total_tasks}) >= max_runs ({max_runs})")
            break
        if running_tasks_memory + celery_run_task["memory"] >= max_tasks_memory:
            logger.info("Not enough available memory to scale another run.")
            break
        task_name = run.uid

        running_tasks_by_queue = client.get_running_tasks(app_name, task_name)
        running_tasks_by_queue_count = running_tasks_by_queue["pagination"].get("total_results", 0)

        logger.info(f"Currently {running_tasks_by_queue_count} tasks running for {task_name}.")
        if running_tasks_by_queue_count:
            logger.info(f"Already a consumer for {task_name}")
            continue

        session_token = None
        user_session = UserSession.objects.filter(user=run.user).last()
        if user_session:
            # Use filter().first() rather than get(): a UserSession that points at a session row which
            # no longer exists must not raise DoesNotExist and abort scaling for every remaining run.
            session = Session.objects.filter(session_key=user_session.session_id).first()
            if session:
                session_token = session.get_decoded().get("session_token")

        user_details = get_user_details(run.user)
        pick_up_run_task.s(
            run_uid=str(run.uid), session_token=session_token, user_details=user_details
        ).apply_async(queue=str(task_name), routing_key=str(task_name))
        celery_run_task["command"] = celery_run_task["command"].format(celery_group_name=task_name)
        run_task_command(client, app_name, str(task_name), celery_run_task)
        # Keep track of new resources being used.
        total_tasks += 1
        running_tasks_memory += celery_run_task["memory"]
예제 #7
0
 def estimated_finish(self):
     """Return the cached ``estimated_finish`` value, or ``None`` once the status is a finished state.

     The cache lookup falls back to ``0`` when no value is stored.
     """
     current_state = TaskState[self.status]
     if current_state not in TaskState.get_finished_states():
         return get_cache_value(obj=self, attribute="estimated_finish", default=0)
     # Finished objects have no estimate.
     return None
예제 #8
0
 def progress(self):
     """Return ``100`` when the status is a finished state, otherwise the cached ``progress`` value.

     The cache lookup falls back to ``0`` when no value is stored.
     """
     current_state = TaskState[self.status]
     if current_state not in TaskState.get_finished_states():
         return get_cache_value(obj=self, attribute="progress", default=0)
     return 100