Exemplo n.º 1
0
def success_branch__callable(*, dag_run: DagRun, **kwargs):
    """
    Decide whether the DAG run should follow the quarantine or archive branch.

    Inspects the task instances of the "init", "extract", "transform" and
    "load" tasks of ``dag_run``. If any of them is in the "failed" state the
    "quarantine" branch is chosen; otherwise the "archive" branch is chosen.

    Args:
        dag_run: the DAG run whose task instances are inspected.
        **kwargs: remaining task context; accepted and ignored.

    Returns:
        The id of the branch to follow: "quarantine" or "archive".
    """
    logger.info("Dag run: %s", dag_run)

    for task_id in ("init", "extract", "transform", "load"):
        task_instance = dag_run.get_task_instance(task_id)
        # get_task_instance may return None when no instance exists for the
        # task in this run; a missing instance has not failed, so skip it
        # instead of crashing on ``None.state``.
        if task_instance is not None and task_instance.state == "failed":
            # One failure is enough; no need to look up the remaining tasks.
            return "quarantine"

    return "archive"
    def skip_task_instance(self):
        """Mark a single task instance of a DAG run as SKIPPED.

        Reads ``dag_id``, ``run_id`` and ``task_id`` from the current request,
        looks up the matching DagRun in a new session, fetches the task
        instance for ``task_id`` from it, sets its state to ``State.SKIPPED``
        and commits. Only the named task instance is updated; this method does
        not change any downstream task instances.

        args:
            dag_id: dag id
            run_id: the run id of dag run
            task_id: the task id of task instance of dag

        Returns:
            ``ApiResponse.not_found(...)`` if the dag run or the task instance
            does not exist, otherwise ``ApiResponse.success()``.
        """
        logging.info("Executing custom 'skip_task_instance' function")

        dag_id = self.get_argument(request, 'dag_id')
        run_id = self.get_argument(request, 'run_id')
        task_id = self.get_argument(request, 'task_id')

        session = settings.Session()
        try:
            dag_run = session.query(DagRun).filter(
                DagRun.dag_id == dag_id,
                DagRun.run_id == run_id
            ).first()

            if dag_run is None:
                return ApiResponse.not_found("dag run is not found")

            logging.info('dag_run:%s', dag_run)

            task_instance = DagRun.get_task_instance(dag_run, task_id)

            if task_instance is None:
                return ApiResponse.not_found("dag task is not found")

            logging.info('task_instance:%s', task_instance)

            task_instance.state = State.SKIPPED
            # merge() attaches the task instance to this session before commit.
            session.merge(task_instance)
            session.commit()
        finally:
            # Release the session on every path, including the not-found early
            # returns and any exception raised above.
            session.close()

        return ApiResponse.success()
    def task_instance_detail(self):
        """Return the details of a single task instance of a DAG run.

        Reads ``dag_id``, ``run_id`` and ``task_id`` from the current request,
        looks up the matching DagRun in a new session and fetches the task
        instance for ``task_id`` from it. The task instance is formatted with
        ``ResponseFormat.format_dag_task``. Per the original description, the
        result carries taskId, dagId, state, tryNumber, maxTries, startDate,
        endDate and duration.

        args:
            dag_id: dag id
            run_id: the run id of dag run
            task_id: the task id of task instance of dag

        Returns:
            ``ApiResponse.not_found(...)`` if the dag run or the task instance
            does not exist, otherwise ``ApiResponse.success(<formatted task>)``.
        """
        logging.info("Executing custom 'task_instance_detail' function")

        dag_id = self.get_argument(request, 'dag_id')
        run_id = self.get_argument(request, 'run_id')
        task_id = self.get_argument(request, 'task_id')

        session = settings.Session()
        try:
            dag_run = session.query(DagRun).filter(
                DagRun.dag_id == dag_id,
                DagRun.run_id == run_id
            ).first()

            if dag_run is None:
                return ApiResponse.not_found("dag run is not found")

            logging.info('dag_run:%s', dag_run)

            task_instance = DagRun.get_task_instance(dag_run, task_id)

            if task_instance is None:
                return ApiResponse.not_found("dag task is not found")

            logging.info('task_instance:%s', task_instance)

            # Format while the session is still open, as the original did.
            res_task_instance = ResponseFormat.format_dag_task(task_instance)
        finally:
            # Release the session on every path, including the not-found early
            # returns and any exception raised above.
            session.close()

        return ApiResponse.success(res_task_instance)
    def kill_running_tasks(self):
        """Mark RUNNING / NONE-state task instances of a DAG run as FAILED.

        Reads ``dag_id``, ``run_id`` and ``task_id`` from the current request
        and looks up the matching DagRun in a new session.

        If ``task_id`` is given, only that task instance is targeted, and only
        when its state is RUNNING or NONE. If ``task_id`` is empty, every task
        instance of the dag run whose state is RUNNING or NONE is targeted.

        Each targeted task instance gets state FAILED and ``end_date`` set to
        the current UTC time. All changes are committed together. Downstream
        task instances are not handled separately by this method.

        args:
            dag_id: dag id
            run_id: the run id of dag run
            task_id: the task id of task instance of dag (optional)

        Returns:
            ``ApiResponse.not_found(...)`` in three cases: the dag run is
            missing, the named task is missing or not RUNNING/NONE, or there
            are no matching task instances.
            ``ApiResponse.bad_request(...)`` if ``dag_id`` is not in the dagbag.
            ``ApiResponse.success()`` otherwise.
        """
        logging.info("Executing custom 'kill_running_tasks' function")

        dagbag = self.get_dagbag()

        dag_id = self.get_argument(request, 'dag_id')
        run_id = self.get_argument(request, 'run_id')
        task_id = self.get_argument(request, 'task_id')

        session = settings.Session()
        try:
            dag_run = session.query(DagRun).filter(
                DagRun.dag_id == dag_id,
                DagRun.run_id == run_id
            ).first()

            if dag_run is None:
                return ApiResponse.not_found("dag run is not found")

            if dag_id not in dagbag.dags:
                return ApiResponse.bad_request("Dag id {} not found".format(dag_id))

            dag = dagbag.get_dag(dag_id)
            logging.info('dag: %s', dag)
            logging.info('dag_subdag: %s', dag.subdags)

            killable_states = [State.RUNNING, State.NONE]
            if task_id:
                task_instance = DagRun.get_task_instance(dag_run, task_id)
                if task_instance is None or task_instance.state not in killable_states:
                    return ApiResponse.not_found("task is not found or state is neither RUNNING nor NONE")
                tis = [task_instance]
            else:
                tis = DagRun.get_task_instances(dag_run, killable_states)

            logging.info('tis: %s', tis)

            if not tis:
                return ApiResponse.not_found("dagRun don't have running tasks")

            for ti in tis:
                ti.state = State.FAILED
                ti.end_date = timezone.utcnow()
                session.merge(ti)
            # Commit once: either every selected task instance is marked FAILED
            # or, if anything above raises, none are (close() discards the rest).
            session.commit()
        finally:
            # Release the session on every path, including the early returns.
            session.close()

        return ApiResponse.success()