def success_branch__callable(*, dag_run: DagRun, **kwargs):
    """Choose between the "quarantine" and the "archive" branch.

    Looks up the task instances of the "init", "extract", "transform" and
    "load" tasks in ``dag_run``. If any of them is in the "failed" state the
    "quarantine" branch is chosen; otherwise the "archive" branch is chosen.

    Args:
        dag_run: The dag run whose task instances are inspected.
        **kwargs: Accepted and ignored.

    Returns:
        The name of the branch to follow: ``"quarantine"`` or ``"archive"``.
    """

    def _has_failed(task_id):
        # The lookup can come back empty (None) when the dag run has no
        # instance for this task; such a task has not failed.
        task_instance = dag_run.get_task_instance(task_id)
        return task_instance is not None and task_instance.state == "failed"

    previous_task_failures = [
        _has_failed(task_id)
        for task_id in ("init", "extract", "transform", "load")
    ]
    logger.info("Dag run: %s", dag_run)

    return "quarantine" if any(previous_task_failures) else "archive"
def skip_task_instance(self):
    """Mark the requested task instance as SKIPPED.

    Reads ``dag_id``, ``run_id`` and ``task_id`` from the current request,
    looks up the matching dag run and its task instance, sets the task
    instance's state to ``State.SKIPPED`` and commits the change. Downstream
    tasks are not modified by this function itself.

    args:
        dag_id: dag id
        run_id: the run id of dag run
        task_id: the task id of task instance of dag

    Returns:
        ``ApiResponse.not_found(...)`` when the dag run or the task instance
        does not exist, ``ApiResponse.success()`` otherwise.
    """
    logging.info("Executing custom 'skip_task_instance' function")

    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')
    task_id = self.get_argument(request, 'task_id')

    session = settings.Session()
    try:
        dag_run = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.run_id == run_id
        ).first()
        if dag_run is None:
            return ApiResponse.not_found("dag run is not found")
        logging.info('dag_run:%s', dag_run)

        task_instance = dag_run.get_task_instance(task_id)
        if task_instance is None:
            return ApiResponse.not_found("dag task is not found")
        logging.info('task_instance:%s', task_instance)

        task_instance.state = State.SKIPPED
        session.merge(task_instance)
        session.commit()
        return ApiResponse.success()
    finally:
        # Release the session on every exit path, including the early
        # "not found" returns and unexpected exceptions.
        session.close()
def task_instance_detail(self):
    """Return the details of the requested task instance.

    Reads ``dag_id``, ``run_id`` and ``task_id`` from the current request,
    looks up the matching dag run and its task instance, and returns the
    task instance formatted by ``ResponseFormat.format_dag_task``.

    args:
        dag_id: dag id
        run_id: the run id of dag run
        task_id: the task id of task instance of dag

    Returns:
        ``ApiResponse.not_found(...)`` when the dag run or the task instance
        does not exist, otherwise ``ApiResponse.success(...)`` carrying the
        formatted task instance.
    """
    logging.info("Executing custom 'task_instance_detail' function")

    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')
    task_id = self.get_argument(request, 'task_id')

    session = settings.Session()
    try:
        dag_run = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.run_id == run_id
        ).first()
        if dag_run is None:
            return ApiResponse.not_found("dag run is not found")
        logging.info('dag_run:%s', dag_run)

        task_instance = dag_run.get_task_instance(task_id)
        if task_instance is None:
            return ApiResponse.not_found("dag task is not found")
        logging.info('task_instance:%s', task_instance)

        res_task_instance = ResponseFormat.format_dag_task(task_instance)
        return ApiResponse.success(res_task_instance)
    finally:
        # Release the session on every exit path, including the early
        # "not found" returns and unexpected exceptions.
        session.close()
def kill_running_tasks(self):
    """Mark running or not-yet-started task instances of a dag run as FAILED.

    Reads ``dag_id``, ``run_id`` and ``task_id`` from the current request and
    looks up the matching dag run.

    If ``task_id`` is not empty, only that task instance is considered, and
    it must be in the RUNNING or NONE state. If ``task_id`` is empty, all
    task instances of the dag run in the RUNNING or NONE state are
    considered. Every selected task instance gets its state set to FAILED
    and its end date set to the current UTC time.

    args:
        dag_id: dag id
        run_id: the run id of dag run
        task_id: the task id of task instance of dag

    Returns:
        ``ApiResponse.not_found(...)`` when the dag run does not exist, when
        the requested task is missing or in another state, or when no task
        instance qualifies; ``ApiResponse.bad_request(...)`` when the dag is
        not in the dag bag; ``ApiResponse.success()`` otherwise.
    """
    logging.info("Executing custom 'kill_running_tasks' function")
    dagbag = self.get_dagbag()

    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')
    task_id = self.get_argument(request, 'task_id')

    session = settings.Session()
    try:
        dag_run = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.run_id == run_id
        ).first()
        if dag_run is None:
            return ApiResponse.not_found("dag run is not found")

        if dag_id not in dagbag.dags:
            return ApiResponse.bad_request("Dag id {} not found".format(dag_id))

        dag = dagbag.get_dag(dag_id)
        logging.info('dag: %s', dag)
        logging.info('dag_subdag: %s', dag.subdags)

        killable_states = [State.RUNNING, State.NONE]
        if task_id:
            task_instance = dag_run.get_task_instance(task_id)
            if task_instance is None or task_instance.state not in killable_states:
                return ApiResponse.not_found("task is not found or state is neither RUNNING nor NONE")
            tis = [task_instance]
        else:
            tis = dag_run.get_task_instances(killable_states)
        logging.info('tis: %s', tis)

        if not tis:
            return ApiResponse.not_found("dagRun don't have running tasks")

        for ti in tis:
            ti.state = State.FAILED
            ti.end_date = timezone.utcnow()
            session.merge(ti)
        # Commit once so all selected task instances are updated together.
        session.commit()
        return ApiResponse.success()
    finally:
        # Release the session on every exit path, including the early
        # error returns and unexpected exceptions.
        session.close()