def get_dag_runs(dag_id, state=None):
    """
    Returns a list of Dag Runs for a specific DAG ID.

    :param dag_id: String identifier of a DAG
    :param state: queued|running|success...
    :return: List of DAG runs of a DAG with requested state,
        or all runs if the state is not specified
    """
    dagbag = DagBag()

    # Check DAG exists.
    if dag_id not in dagbag.dags:
        error_message = "Dag id {} not found".format(dag_id)
        raise AirflowException(error_message)

    dag_runs = list()
    state = state.lower() if state else None
    for run in DagRun.find(dag_id=dag_id, state=state):
        dag_runs.append({
            'id': run.id,
            'run_id': run.run_id,
            'state': run.state,
            'dag_id': run.dag_id,
            'execution_date': run.execution_date.isoformat(),
            'start_date': ((run.start_date or '') and
                           run.start_date.isoformat()),
            'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id,
                                   execution_date=run.execution_date)
        })

    return dag_runs
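# Hypothetical usage sketch for get_dag_runs above, not part of the original
# snippet: it assumes an Airflow metadata DB containing a DAG named
# 'example_bash_operator' and an active Flask app context for url_for().
for run in get_dag_runs('example_bash_operator', state='running'):
    print(run['run_id'], run['state'], run['execution_date'])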
def test_overwrite_params_with_dag_run_conf(self):
    task = DummyOperator(task_id='op')
    ti = TI(task=task, execution_date=datetime.datetime.now())
    dag_run = DagRun()
    dag_run.conf = {"override": True}
    params = {"override": False}

    ti.overwrite_params_with_dag_run_conf(params, dag_run)

    self.assertEqual(True, params["override"])
def _get_dep_statuses(self, ti, session, dep_context):
    dag = ti.task.dag
    dagrun = ti.get_dagrun(session)
    if not dagrun:
        # The import is needed here to avoid a circular dependency
        from airflow.models import DagRun
        running_dagruns = DagRun.find(
            dag_id=dag.dag_id,
            state=State.RUNNING,
            external_trigger=False,
            session=session
        )

        if len(running_dagruns) >= dag.max_active_runs:
            reason = ("The maximum number of active dag runs ({0}) for this task "
                      "instance's DAG '{1}' has been reached.".format(
                          dag.max_active_runs, ti.dag_id))
        else:
            reason = "Unknown reason"
        yield self._failing_status(
            reason="Task instance's dagrun did not exist: {0}.".format(reason))
    else:
        if dagrun.state != State.RUNNING:
            yield self._failing_status(
                reason="Task instance's dagrun was not in the 'running' state but in "
                       "the state '{}'.".format(dagrun.state))
def trigger_dag(dag_id, run_id=None, conf=None, execution_date=None):
    dagbag = DagBag()

    if dag_id not in dagbag.dags:
        raise AirflowException("Dag id {} not found".format(dag_id))

    dag = dagbag.get_dag(dag_id)

    if not execution_date:
        execution_date = datetime.now()

    if not run_id:
        run_id = "manual__{0}".format(execution_date.isoformat())

    dr = DagRun.find(dag_id=dag_id, run_id=run_id)
    if dr:
        raise AirflowException("Run id {} already exists for dag id {}".format(
            run_id, dag_id
        ))

    run_conf = None
    if conf:
        run_conf = json.loads(conf)

    trigger = dag.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.RUNNING,
        conf=run_conf,
        external_trigger=True
    )

    return trigger
def trigger_dag(args):
    dag = get_dag(args)

    if not dag:
        logging.error("Cannot find dag {}".format(args.dag_id))
        sys.exit(1)

    execution_date = datetime.now()
    run_id = args.run_id or "manual__{0}".format(execution_date.isoformat())

    dr = DagRun.find(dag_id=args.dag_id, run_id=run_id)
    if dr:
        logging.error("This run_id {} already exists".format(run_id))
        raise AirflowException()

    run_conf = {}
    if args.conf:
        run_conf = json.loads(args.conf)

    trigger = dag.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.RUNNING,
        conf=run_conf,
        external_trigger=True
    )
    logging.info("Created {}".format(trigger))
def set_dag_run_state(dag, execution_date, state=State.SUCCESS, commit=False):
    """
    Set the state of a dag run and all task instances associated with the dag
    run for a specific execution date.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking
    :param state: the state to which the DAG needs to be set
    :param commit: commit DAG and tasks to be altered to the database
    :return: list of tasks that have been created and updated
    :raises: AssertionError if dag or execution_date is invalid
    """
    res = []

    if not dag or not execution_date:
        return res

    # Mark all task instances in the dag run
    for task in dag.tasks:
        task.dag = dag
        new_state = set_state(task=task, execution_date=execution_date,
                              state=state, commit=commit)
        res.extend(new_state)

    # Mark the dag run
    if commit:
        drs = DagRun.find(dag.dag_id, execution_date=execution_date)
        for dr in drs:
            dr.dag = dag
            dr.update_state()

    return res
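# Hypothetical usage sketch for set_dag_run_state above; 'my_dag' and the
# execution date are illustrative and assume an existing run for that date.
altered = set_dag_run_state(my_dag,
                            execution_date=timezone.datetime(2020, 1, 1),
                            state=State.SUCCESS, commit=True)
print("updated {} task instances".format(len(altered)))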
def dag_state(args):
    """
    Returns the state of a DagRun at the command line.

    >>> airflow dag_state tutorial 2015-01-01T00:00:00.000000
    running
    """
    dag = get_dag(args)
    dr = DagRun.find(dag.dag_id, execution_date=args.execution_date)
    print(dr[0].state if len(dr) > 0 else None)
def evaluate_dagrun(
        self,
        dag_id,
        expected_task_states,  # dict of task_id: state
        dagrun_state,
        run_kwargs=None,
        advance_execution_date=False,
        session=None):
    """
    Helper for testing DagRun states with simple two-task DAGs.
    This is hackish: a dag run is created but its tasks are
    run by a backfill.
    """
    if run_kwargs is None:
        run_kwargs = {}

    scheduler = SchedulerJob(**self.default_scheduler_args)
    dag = self.dagbag.get_dag(dag_id)
    dag.clear()
    dr = scheduler.create_dag_run(dag)

    if advance_execution_date:
        # run a second time to schedule a dagrun after the start_date
        dr = scheduler.create_dag_run(dag)
    ex_date = dr.execution_date

    try:
        dag.run(start_date=ex_date, end_date=ex_date, **run_kwargs)
    except AirflowException:
        pass

    # test tasks
    for task_id, expected_state in expected_task_states.items():
        task = dag.get_task(task_id)
        ti = TI(task, ex_date)
        ti.refresh_from_db()
        self.assertEqual(ti.state, expected_state)

    # load dagrun
    dr = DagRun.find(dag_id=dag_id, execution_date=ex_date)
    dr = dr[0]
    dr.dag = dag

    # dagrun is running
    self.assertEqual(dr.state, State.RUNNING)

    dr.update_state()

    # dagrun should now be in the expected terminal state
    self.assertEqual(dr.state, dagrun_state)
def latest_dag_runs():
    """Returns the latest DagRun for each DAG formatted for the UI."""
    from airflow.models import DagRun
    dagruns = DagRun.get_latest_runs()
    payload = []
    for dagrun in dagruns:
        if dagrun.execution_date:
            payload.append({
                'dag_id': dagrun.dag_id,
                'execution_date': dagrun.execution_date.strftime("%Y-%m-%d %H:%M"),
                'start_date': ((dagrun.start_date or '') and
                               dagrun.start_date.strftime("%Y-%m-%d %H:%M")),
                'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id,
                                       execution_date=dagrun.execution_date)
            })
    # old flask versions don't support jsonifying arrays
    return jsonify(items=payload)
def _create_dagruns(dag, execution_dates, state, run_id_template):
    """
    Infers from the dates which dag runs need to be created and does so.

    :param dag: the dag to create dag runs for
    :param execution_dates: list of execution dates to evaluate
    :param state: the state to set the dag run to
    :param run_id_template: the template for the run id, formatted with the
        execution date
    :return: newly created and existing dag runs for the execution dates supplied
    """
    # find out if we need to create any dag runs
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=execution_dates)
    dates_to_create = list(set(execution_dates) -
                           set([dr.execution_date for dr in drs]))

    for date in dates_to_create:
        dr = dag.create_dagrun(
            run_id=run_id_template.format(date.isoformat()),
            execution_date=date,
            start_date=timezone.utcnow(),
            external_trigger=False,
            state=state,
        )
        drs.append(dr)

    return drs
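# Hypothetical usage sketch for _create_dagruns above; 'my_dag' and the dates
# are illustrative, and the run-id template mirrors the backfill prefix used
# by the callers in this section.
dates = [timezone.datetime(2020, 1, 1), timezone.datetime(2020, 1, 2)]
runs = _create_dagruns(my_dag, execution_dates=dates,
                       state=State.RUNNING, run_id_template="backfill_{}")
print([dr.run_id for dr in runs])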
class TestDecorators(unittest.TestCase):

    EXAMPLE_DAG_DEFAULT_DATE = dates.days_ago(2)
    run_id = "test_{}".format(DagRun.id_for_date(EXAMPLE_DAG_DEFAULT_DATE))

    @classmethod
    def setUpClass(cls):
        cls.dagbag = DagBag(include_examples=True)
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        cls.app = app.test_client()

    def setUp(self):
        self.session = Session()
        self.cleanup_dagruns()
        self.prepare_dagruns()

    def cleanup_dagruns(self):
        DR = DagRun
        dag_ids = 'example_bash_operator'
        (self.session
            .query(DR)
            .filter(DR.dag_id == dag_ids)
            .filter(DR.run_id == self.run_id)
            .delete(synchronize_session='fetch'))
        self.session.commit()

    def prepare_dagruns(self):
        self.bash_dag = self.dagbag.dags['example_bash_operator']
        self.bash_dag.sync_to_db()

        self.bash_dagrun = self.bash_dag.create_dagrun(
            run_id=self.run_id,
            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE,
            start_date=timezone.utcnow(),
            state=State.RUNNING)

    def check_last_log(self, dag_id, event, execution_date=None):
        qry = self.session.query(Log.dag_id, Log.task_id, Log.event,
                                 Log.execution_date, Log.owner, Log.extra)
        qry = qry.filter(Log.dag_id == dag_id, Log.event == event)
        if execution_date:
            qry = qry.filter(Log.execution_date == execution_date)
        logs = qry.order_by(Log.dttm.desc()).limit(5).all()
        self.assertGreaterEqual(len(logs), 1)
        self.assertTrue(logs[0].extra)

    def test_action_logging_get(self):
        url = '/admin/airflow/graph?dag_id=example_bash_operator&execution_date={}'.format(
            quote_plus(self.EXAMPLE_DAG_DEFAULT_DATE.isoformat().encode('utf-8')))
        self.app.get(url, follow_redirects=True)

        # In mysql backend, this commit() is needed to write down the logs
        self.session.commit()
        self.check_last_log("example_bash_operator", event="graph",
                            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE)

    def test_action_logging_post(self):
        form = dict(
            task_id="runme_1",
            dag_id="example_bash_operator",
            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE.isoformat().encode('utf-8'),
            upstream="false",
            downstream="false",
            future="false",
            past="false",
            only_failed="false",
        )
        self.app.post("/admin/airflow/clear", data=form)

        # In mysql backend, this commit() is needed to write down the logs
        self.session.commit()
        self.check_last_log("example_bash_operator", event="clear",
                            execution_date=self.EXAMPLE_DAG_DEFAULT_DATE)
def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
    DEFAULT_DATE = datetime.datetime(2020, 5, 17)
    mock_emitter = Mock()
    mock_emit.return_value = mock_emitter
    # Using autospec on xcom_pull and xcom_push methods fails on Python 3.6.
    with mock.patch.dict(
        os.environ,
        {
            "AIRFLOW__LINEAGE__BACKEND": "datahub_provider.lineage.datahub.DatahubLineageBackend",
            "AIRFLOW__LINEAGE__DATAHUB_CONN_ID": datahub_rest_connection_config.conn_id,
            "AIRFLOW__LINEAGE__DATAHUB_KWARGS": json.dumps(
                {"graceful_exceptions": False, "capture_executions": True}
            ),
        },
    ), mock.patch("airflow.models.BaseOperator.xcom_pull"), mock.patch(
        "airflow.models.BaseOperator.xcom_push"
    ), patch_airflow_connection(datahub_rest_connection_config):
        func = mock.Mock()
        func.__name__ = "foo"

        dag = DAG(dag_id="test_lineage_is_sent_to_backend", start_date=DEFAULT_DATE)

        with dag:
            op1 = DummyOperator(
                task_id="task1_upstream",
                inlets=inlets,
                outlets=outlets,
            )
            op2 = DummyOperator(
                task_id="task2",
                inlets=inlets,
                outlets=outlets,
            )
            op1 >> op2

        # Airflow < 2.2 requires the execution_date parameter. Newer Airflow
        # versions do not require it, but will attempt to find the associated
        # run_id in the database if execution_date is provided. As such, we
        # must fake the run_id parameter for newer Airflow versions.
        if AIRFLOW_VERSION < packaging.version.parse("2.2.0"):
            ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE)
            # Ignoring type here because DagRun state is just a string at Airflow 1
            dag_run = DagRun(state="success", run_id=f"scheduled_{DEFAULT_DATE}")  # type: ignore
            ti.dag_run = dag_run
            ti.start_date = datetime.datetime.utcnow()
            ti.execution_date = DEFAULT_DATE
        else:
            from airflow.utils.state import DagRunState

            ti = TaskInstance(task=op2, run_id=f"test_airflow-{DEFAULT_DATE}")
            dag_run = DagRun(state=DagRunState.SUCCESS, run_id=f"scheduled_{DEFAULT_DATE}")
            ti.dag_run = dag_run
            ti.start_date = datetime.datetime.utcnow()
            ti.execution_date = DEFAULT_DATE

        ctx1 = {
            "dag": dag,
            "task": op2,
            "ti": ti,
            "dag_run": dag_run,
            "task_instance": ti,
            "execution_date": DEFAULT_DATE,
            "ts": "2021-04-08T00:54:25.771575+00:00",
        }

        prep = prepare_lineage(func)
        prep(op2, ctx1)
        post = apply_lineage(func)
        post(op2, ctx1)

        # Verify that the inlets and outlets are registered and recognized
        # by Airflow correctly, or that our lineage backend forces it to.
        assert len(op2.inlets) == 1
        assert len(op2.outlets) == 1
        assert all(map(lambda let: isinstance(let, Dataset), op2.inlets))
        assert all(map(lambda let: isinstance(let, Dataset), op2.outlets))

        # Check that the right things were emitted.
        assert mock_emitter.emit.call_count == 17

        # Running further checks based on python version because
        # args only exists in python 3.7+
        if sys.version_info[:3] > (3, 7):
            assert mock_emitter.method_calls[0].args[0].aspectName == "dataFlowInfo"
            assert (
                mock_emitter.method_calls[0].args[0].entityUrn
                == "urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod)"
            )

            assert mock_emitter.method_calls[1].args[0].aspectName == "ownership"
            assert (
                mock_emitter.method_calls[1].args[0].entityUrn
                == "urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod)"
            )

            assert mock_emitter.method_calls[2].args[0].aspectName == "globalTags"
            assert (
                mock_emitter.method_calls[2].args[0].entityUrn
                == "urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod)"
            )

            assert mock_emitter.method_calls[3].args[0].aspectName == "dataJobInfo"
            assert (
                mock_emitter.method_calls[3].args[0].entityUrn
                == "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task2)"
            )

            assert mock_emitter.method_calls[4].args[0].aspectName == "dataJobInputOutput"
            assert (
                mock_emitter.method_calls[4].args[0].entityUrn
                == "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task2)"
            )
            assert (
                mock_emitter.method_calls[4].args[0].aspect.inputDatajobs[0]
                == "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task1_upstream)"
            )
            assert (
                mock_emitter.method_calls[4].args[0].aspect.inputDatasets[0]
                == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
            )
            assert (
                mock_emitter.method_calls[4].args[0].aspect.outputDatasets[0]
                == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableProduced,PROD)"
            )

            assert mock_emitter.method_calls[5].args[0].aspectName == "status"
            assert (
                mock_emitter.method_calls[5].args[0].entityUrn
                == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
            )

            assert mock_emitter.method_calls[6].args[0].aspectName == "status"
            assert (
                mock_emitter.method_calls[6].args[0].entityUrn
                == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableProduced,PROD)"
            )

            assert mock_emitter.method_calls[7].args[0].aspectName == "ownership"
            assert (
                mock_emitter.method_calls[7].args[0].entityUrn
                == "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task2)"
            )

            assert mock_emitter.method_calls[8].args[0].aspectName == "globalTags"
            assert (
                mock_emitter.method_calls[8].args[0].entityUrn
                == "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task2)"
            )

            assert mock_emitter.method_calls[9].args[0].aspectName == "dataProcessInstanceProperties"
            assert (
                mock_emitter.method_calls[9].args[0].entityUrn
                == "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"
            )

            assert mock_emitter.method_calls[10].args[0].aspectName == "dataProcessInstanceRelationships"
            assert (
                mock_emitter.method_calls[10].args[0].entityUrn
                == "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"
            )

            assert mock_emitter.method_calls[11].args[0].aspectName == "dataProcessInstanceInput"
            assert (
                mock_emitter.method_calls[11].args[0].entityUrn
                == "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"
            )

            assert mock_emitter.method_calls[12].args[0].aspectName == "dataProcessInstanceOutput"
            assert (
                mock_emitter.method_calls[12].args[0].entityUrn
                == "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"
            )

            assert mock_emitter.method_calls[13].args[0].aspectName == "status"
            assert (
                mock_emitter.method_calls[13].args[0].entityUrn
                == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
            )

            assert mock_emitter.method_calls[14].args[0].aspectName == "status"
            assert (
                mock_emitter.method_calls[14].args[0].entityUrn
                == "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableProduced,PROD)"
            )

            assert mock_emitter.method_calls[15].args[0].aspectName == "dataProcessInstanceRunEvent"
            assert (
                mock_emitter.method_calls[15].args[0].entityUrn
                == "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"
            )

            assert mock_emitter.method_calls[16].args[0].aspectName == "dataProcessInstanceRunEvent"
            assert (
                mock_emitter.method_calls[16].args[0].entityUrn
                == "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"
            )
def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
    """
    Returns a dag run for the given run date, which will be matched to an
    existing dag run if available or create a new dag run otherwise. If the
    max_active_runs limit is reached, this function will return None.

    :param run_date: the execution date for the dag run
    :param dag: DAG
    :param session: the database session object
    :return: a DagRun in state RUNNING or None
    """
    run_id = BackfillJob.ID_FORMAT_PREFIX.format(run_date.isoformat())

    # consider max_active_runs but ignore when running subdags
    respect_dag_max_active_limit = bool(dag.schedule_interval and
                                        not dag.is_subdag)

    current_active_dag_count = dag.get_num_active_runs(external_trigger=False)

    # check if we are scheduling on top of an already existing dag_run
    # we could find a "scheduled" run instead of a "backfill"
    run = DagRun.find(dag_id=dag.dag_id,
                      execution_date=run_date,
                      session=session)
    if run is not None and len(run) > 0:
        run = run[0]
        if run.state == State.RUNNING:
            respect_dag_max_active_limit = False
    else:
        run = None

    # enforce max_active_runs limit for dag, special cases already
    # handled by respect_dag_max_active_limit
    if (respect_dag_max_active_limit and
            current_active_dag_count >= dag.max_active_runs):
        return None

    run = run or dag.create_dagrun(
        run_id=run_id,
        execution_date=run_date,
        start_date=timezone.utcnow(),
        state=State.RUNNING,
        external_trigger=False,
        session=session,
        conf=self.conf,
    )

    # set required transient field
    run.dag = dag

    # explicitly mark as backfill and running
    run.state = State.RUNNING
    run.run_id = run_id
    run.verify_integrity(session=session)
    return run
def create_dagrun_from_dbnd_run(
    databand_run,
    dag,
    execution_date,
    state=State.RUNNING,
    external_trigger=False,
    conf=None,
    session=None,
):
    """Create a new DagRun and all relevant TaskInstances."""
    dagrun = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
        .first()
    )
    if dagrun is None:
        dagrun = DagRun(
            run_id=databand_run.run_id,
            execution_date=execution_date,
            start_date=dag.start_date,
            _state=state,
            external_trigger=external_trigger,
            dag_id=dag.dag_id,
            conf=conf,
        )
        session.add(dagrun)
    else:
        logger.warning("Running with existing airflow dag run %s", dagrun)

    dagrun.dag = dag
    dagrun.run_id = databand_run.run_id
    session.commit()

    # create the associated task instances
    # state is None at the moment of creation

    # dagrun.verify_integrity(session=session)  # fetches [TaskInstance] again
    # tasks_skipped = databand_run.tasks_skipped

    # we can find a source of the completion, but also,
    # sometimes we don't know the source of the "complete"
    TI = TaskInstance
    tis = (
        session.query(TI)
        .filter(TI.dag_id == dag.dag_id, TI.execution_date == execution_date)
        .all()
    )
    tis = {ti.task_id: ti for ti in tis}

    for af_task in dag.tasks:
        ti = tis.get(af_task.task_id)
        if ti is None:
            ti = TaskInstance(af_task, execution_date=execution_date)
            ti.start_date = timezone.utcnow()
            ti.end_date = timezone.utcnow()
            session.add(ti)
        task_run = databand_run.get_task_run_by_af_id(af_task.task_id)
        # all tasks part of the backfill are scheduled to dagrun

        # Set log file path to expected airflow log file path
        task_run.log.local_log_file.path = ti.log_filepath.replace(
            ".log", "/{0}.log".format(ti.try_number)
        )
        if task_run.is_reused:
            # this task is completed and we don't need to run it anymore
            ti.state = State.SUCCESS

    session.commit()

    return dagrun
def execute(self, context):
    started_at = datetime.utcnow()
    _keep_going = True
    while _keep_going:
        _force_run_data = self.get_force_run_data()
        _logger.info("Force run data: {}".format(_force_run_data))
        if not _force_run_data:
            if (datetime.utcnow() - started_at).total_seconds() > self.timeout:
                raise AirflowSkipException('Snap. Time is OUT.')
            sleep(self.poke_interval)
            continue

        for row in _force_run_data:
            _keep_going = False
            biowardrobe_uid = row['uid']

            # TODO: Check if dag is running in airflow
            # TODO: If not running!

            data = self.get_record_data(biowardrobe_uid)
            if not data:
                _logger.error('No biowardrobe data {}'.format(biowardrobe_uid))
                continue

            # Actual Force RUN
            basedir = data['output_folder']
            try:
                os.chdir(basedir)
                for root, dirs, files in os.walk(".", topdown=False):
                    for name in files:
                        if "fastq" in name:
                            continue
                        os.remove(os.path.join(root, name))
                rmtree(os.path.join(basedir, 'tophat'), True)
            except Exception:
                pass

            if int(data['deleted']) == 0:
                cmd = 'bunzip2 {}*.fastq.bz2'.format(biowardrobe_uid)
                try:
                    check_output(cmd, shell=True)
                except Exception as e:
                    _logger.error("Can't uncompress: {} {}".format(cmd, str(e)))

                if not os.path.isfile(biowardrobe_uid + '.fastq'):
                    _logger.error("File does not exist: {}".format(biowardrobe_uid))
                    continue
                if not os.path.isfile(biowardrobe_uid + '_2.fastq') and data['pair']:
                    _logger.error("File 2 does not exist: {}".format(biowardrobe_uid))
                    continue
            else:
                rmtree(basedir, True)

            mysql = MySqlHook(mysql_conn_id=biowardrobe_connection_id)
            with closing(mysql.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    self.drop_sql(cursor, data)
                    if int(data['deleted']) == 0:
                        cursor.execute(
                            "update labdata set libstatustxt=%s, libstatus=10, forcerun=0, tagstotal=0,"
                            "tagsmapped=0,tagsribo=0,tagsused=0,tagssuppressed=0 where uid=%s",
                            ("Ready to be reanalyzed", biowardrobe_uid))
                        conn.commit()
                    else:
                        cursor.execute(
                            "update labdata set libstatustxt=%s,deleted=2,datedel=CURDATE() where uid=%s",
                            ("Deleted", biowardrobe_uid))
                        conn.commit()
                        _logger.info("Deleted: {}".format(biowardrobe_uid))
                        continue

            _dag_id = os.path.basename(os.path.splitext(data['workflow'])[0])
            _run_id = 'forcerun__{}__{}'.format(biowardrobe_uid, uuid.uuid4())

            session = settings.Session()
            dr = DagRun(
                dag_id=_dag_id,
                run_id=_run_id,
                conf={'biowardrobe_uid': biowardrobe_uid, 'run_id': _run_id},
                execution_date=datetime.now(),
                start_date=datetime.now(),
                external_trigger=True)
            logging.info("Creating DagRun {}".format(dr))
            session.add(dr)
            session.commit()
            session.close()
def dag_backfill(args, dag=None):
    """Creates backfill job or dry run for a DAG"""
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)

    signal.signal(signal.SIGTERM, sigint_handler)

    import warnings

    warnings.warn(
        '--ignore-first-depends-on-past is deprecated as the value is always set to True',
        category=PendingDeprecationWarning,
    )

    if args.ignore_first_depends_on_past is False:
        args.ignore_first_depends_on_past = True

    if not args.start_date and not args.end_date:
        raise AirflowException("Provide a start_date and/or end_date")

    dag = dag or get_dag(args.subdir, args.dag_id)

    # If only one date is passed, use it as both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.partial_subset(
            task_ids_or_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies)
        if not dag.task_dict:
            raise AirflowException(
                f"There are no tasks that match '{args.task_regex}' regex. Nothing to run, exiting..."
            )

    run_conf = None
    if args.conf:
        run_conf = json.loads(args.conf)

    if args.dry_run:
        print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
        dr = DagRun(dag.dag_id, execution_date=args.start_date)
        for task in dag.tasks:
            print(f"Task {task.task_id}")
            ti = TaskInstance(task, run_id=None)
            ti.dag_run = dr
            ti.dry_run()
    else:
        if args.reset_dagruns:
            DAG.clear_dags(
                [dag],
                start_date=args.start_date,
                end_date=args.end_date,
                confirm_prompt=not args.yes,
                include_subdags=True,
                dag_run_state=DagRunState.QUEUED,
            )

        try:
            dag.run(
                start_date=args.start_date,
                end_date=args.end_date,
                mark_success=args.mark_success,
                local=args.local,
                donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
                ignore_first_depends_on_past=args.ignore_first_depends_on_past,
                ignore_task_deps=args.ignore_dependencies,
                pool=args.pool,
                delay_on_limit_secs=args.delay_on_limit,
                verbose=args.verbose,
                conf=run_conf,
                rerun_failed_tasks=args.rerun_failed_tasks,
                run_backwards=args.run_backwards,
                continue_on_failures=args.continue_on_failures,
            )
        except ValueError as vr:
            print(str(vr))
            sys.exit(1)
def test_should_respond_200_with_tilde_and_access_to_all_dags(self):
    dag_id_1 = 'test-dag-id-1'
    task_id_1 = 'test-task-id-1'
    execution_date = '2005-04-02T00:00:00+00:00'
    execution_date_parsed = parse_execution_date(execution_date)
    dag_run_id_1 = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
    self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1)

    dag_id_2 = 'test-dag-id-2'
    task_id_2 = 'test-task-id-2'
    dag_run_id_2 = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
    self._create_xcom_entries(dag_id_2, dag_run_id_2, execution_date_parsed, task_id_2)

    response = self.client.get(
        "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
        environ_overrides={'REMOTE_USER': "******"},
    )

    self.assertEqual(200, response.status_code)
    response_data = response.json
    for xcom_entry in response_data['xcom_entries']:
        xcom_entry['timestamp'] = "TIMESTAMP"
    self.assertEqual(
        response.json,
        {
            'xcom_entries': [
                {
                    'dag_id': dag_id_1,
                    'execution_date': execution_date,
                    'key': 'test-xcom-key-1',
                    'task_id': task_id_1,
                    'timestamp': "TIMESTAMP",
                },
                {
                    'dag_id': dag_id_1,
                    'execution_date': execution_date,
                    'key': 'test-xcom-key-2',
                    'task_id': task_id_1,
                    'timestamp': "TIMESTAMP",
                },
                {
                    'dag_id': dag_id_2,
                    'execution_date': execution_date,
                    'key': 'test-xcom-key-1',
                    'task_id': task_id_2,
                    'timestamp': "TIMESTAMP",
                },
                {
                    'dag_id': dag_id_2,
                    'execution_date': execution_date,
                    'key': 'test-xcom-key-2',
                    'task_id': task_id_2,
                    'timestamp': "TIMESTAMP",
                },
            ],
            'total_entries': 4,
        },
    )
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (though it will for subdag dag runs if needed).

    :param task: the task from which to work. task.dag needs to be set
    :param execution_date: the execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id,
        including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of tasks that have been created and updated
    """
    assert isinstance(execution_date, datetime.datetime)

    # microseconds are supported by the database, but is not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date
    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    # go through subdagoperators and create dag runs. We will only work
    # within the scope of the subdag. We won't propagate to the parent dag,
    # but we will propagate from parent to subdag.
    session = Session()
    dags = [dag]
    sub_dag_ids = []
    while len(dags) > 0:
        current_dag = dags.pop()
        for task_id in task_ids:
            if not current_dag.has_task(task_id):
                continue

            current_task = current_dag.get_task(task_id)
            if isinstance(current_task, SubDagOperator):
                # this works as a kind of integrity check
                # it creates missing dag runs for subdagoperators,
                # maybe this should be moved to dagrun.verify_integrity
                drs = _create_dagruns(
                    current_task.subdag,
                    execution_dates=confirmed_dates,
                    state=State.RUNNING,
                    run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                for dr in drs:
                    dr.dag = current_task.subdag
                    dr.verify_integrity()
                    if commit:
                        dr.state = state
                        session.merge(dr)

                dags.append(current_task.subdag)
                sub_dag_ids.append(current_task.subdag.dag_id)

    # now look for the task instances that are affected
    TI = TaskInstance

    # get all tasks of the main dag that will be affected by a state change
    qry_dag = session.query(TI).filter(
        TI.dag_id == dag.dag_id,
        TI.execution_date.in_(confirmed_dates),
        TI.task_id.in_(task_ids)).filter(
        or_(TI.state.is_(None), TI.state != state))

    # get *all* tasks of the sub dags
    if len(sub_dag_ids) > 0:
        qry_sub_dag = session.query(TI).filter(
            TI.dag_id.in_(sub_dag_ids),
            TI.execution_date.in_(confirmed_dates)).filter(
            or_(TI.state.is_(None), TI.state != state))

    if commit:
        tis_altered = qry_dag.with_for_update().all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.with_for_update().all()
        for ti in tis_altered:
            ti.state = state
        session.commit()
    else:
        tis_altered = qry_dag.all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.all()

    session.close()

    return tis_altered
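# Hypothetical usage sketch for set_state above; 'my_dag' and the task id are
# illustrative and assume a DAG run already exists for the given date.
task = my_dag.get_task('run_this')
altered = set_state(task=task, execution_date=datetime.datetime(2020, 1, 1),
                    downstream=True, state=State.SUCCESS, commit=True)
print("{} task instances altered".format(len(altered)))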
def test_backfill_max_limit_check(self):
    dag_id = 'test_backfill_max_limit_check'
    run_id = 'test_dagrun'
    start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
    end_date = DEFAULT_DATE

    dag_run_created_cond = threading.Condition()

    def run_backfill(cond):
        cond.acquire()
        # this session object is different than the one in the main thread
        with create_session() as thread_session:
            try:
                dag = self._get_dag_test_max_active_limits(dag_id)

                # Existing dagrun that is not within the backfill range
                dag.create_dagrun(
                    run_id=run_id,
                    state=State.RUNNING,
                    execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
                    start_date=DEFAULT_DATE,
                )

                thread_session.commit()
                cond.notify()
            finally:
                cond.release()
                thread_session.close()

        executor = TestExecutor()
        job = BackfillJob(dag=dag,
                          start_date=start_date,
                          end_date=end_date,
                          executor=executor,
                          donot_pickle=True)
        job.run()

    backfill_job_thread = threading.Thread(target=run_backfill,
                                           name="run_backfill",
                                           args=(dag_run_created_cond,))

    dag_run_created_cond.acquire()
    with create_session() as session:
        backfill_job_thread.start()
        try:
            # at this point backfill can't run since the max_active_runs has been
            # reached, so it is waiting
            dag_run_created_cond.wait(timeout=1.5)
            dagruns = DagRun.find(dag_id=dag_id)
            dr = dagruns[0]
            self.assertEqual(1, len(dagruns))
            self.assertEqual(dr.run_id, run_id)

            # allow the backfill to execute
            # by setting the existing dag run to SUCCESS,
            # backfill will execute dag runs 1 by 1
            dr.set_state(State.SUCCESS)
            session.merge(dr)
            session.commit()

            backfill_job_thread.join()

            dagruns = DagRun.find(dag_id=dag_id)
            self.assertEqual(3, len(dagruns))  # 2 from backfill + 1 existing
            self.assertEqual(dagruns[-1].run_id, dr.run_id)
        finally:
            dag_run_created_cond.release()
def kill_running_tasks(self):
    """Stop running the specified task instance and its downstream tasks.

    Obtain the task_instance from the session according to dag_id, run_id
    and task_id. If task_id is not empty, get the task_instance with
    RUNNING or NONE status from the dag_run according to task_id, and set
    the task_instance status to FAILED. If task_id is empty, get all
    task_instances whose status is RUNNING or NONE from the dag_run, and
    set the status of these task_instances to FAILED.

    args:
        dag_id: dag id
        run_id: the run id of the dag run
        task_id: the task id of the task instance of the dag
    """
    logging.info("Executing custom 'kill_running_tasks' function")

    dagbag = self.get_dagbag()

    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')
    task_id = self.get_argument(request, 'task_id')

    session = settings.Session()
    query = session.query(DagRun)
    dag_run = query.filter(
        DagRun.dag_id == dag_id,
        DagRun.run_id == run_id
    ).first()

    if dag_run is None:
        return ApiResponse.not_found("dag run is not found")

    if dag_id not in dagbag.dags:
        return ApiResponse.bad_request("Dag id {} not found".format(dag_id))

    dag = dagbag.get_dag(dag_id)
    logging.info('dag: ' + str(dag))
    logging.info('dag_subdag: ' + str(dag.subdags))

    tis = []
    if task_id:
        task_instance = DagRun.get_task_instance(dag_run, task_id)
        if task_instance is None or task_instance.state not in [State.RUNNING, State.NONE]:
            return ApiResponse.not_found(
                "task is not found or state is neither RUNNING nor NONE")
        else:
            tis.append(task_instance)
    else:
        tis = DagRun.get_task_instances(dag_run, [State.RUNNING, State.NONE])

    logging.info('tis: ' + str(tis))
    running_task_count = len(tis)

    if running_task_count > 0:
        for ti in tis:
            ti.state = State.FAILED
            ti.end_date = timezone.utcnow()
            session.merge(ti)
        session.commit()
    else:
        return ApiResponse.not_found("dag run doesn't have running tasks")

    session.close()

    return ApiResponse.success()
def execute(self, context: Dict):
    if isinstance(self.execution_date, datetime.datetime):
        execution_date = self.execution_date
    elif isinstance(self.execution_date, str):
        execution_date = timezone.parse(self.execution_date)
        self.execution_date = execution_date
    else:
        execution_date = timezone.utcnow()

    run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
    try:
        # Ignore MyPy type for self.execution_date
        # because it doesn't pick up the timezone.parse() for strings
        dag_run = trigger_dag(
            dag_id=self.trigger_dag_id,
            run_id=run_id,
            conf=self.conf,
            execution_date=self.execution_date,
            replace_microseconds=False,
        )
    except DagRunAlreadyExists as e:
        if self.reset_dag_run:
            self.log.info("Clearing %s on %s", self.trigger_dag_id, self.execution_date)

            # Get target dag object and call clear()
            dag_model = DagModel.get_current(self.trigger_dag_id)
            if dag_model is None:
                raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")

            dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
            dag = dag_bag.get_dag(self.trigger_dag_id)
            dag.clear(start_date=self.execution_date, end_date=self.execution_date)
        else:
            raise e

    if self.wait_for_completion:
        # wait for dag to complete
        while True:
            self.log.info(
                'Waiting for %s on %s to become allowed state %s ...',
                self.trigger_dag_id,
                dag_run.execution_date,
                self.allowed_states,
            )
            time.sleep(self.poke_interval)

            dag_run.refresh_from_db()
            state = dag_run.state
            if state in self.failed_states:
                raise AirflowException(f"{self.trigger_dag_id} failed with failed states {state}")
            if state in self.allowed_states:
                self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
                return
def _trigger_dag(
    dag_id: str,
    dag_bag: DagBag,
    dag_run: DagModel,
    run_id: Optional[str],
    conf: Optional[Union[dict, str]],
    execution_date: Optional[datetime],
    replace_microseconds: bool,
) -> List[DagRun]:  # pylint: disable=too-many-arguments
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param dag_run: DAG Run model
    :param run_id: ID of the dag_run
    :param conf: configuration
    :param execution_date: date of execution
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags
    """
    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized

    if dag_id not in dag_bag.dags:
        raise DagNotFound("Dag id {} not found".format(dag_id))

    execution_date = execution_date if execution_date else timezone.utcnow()

    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if dag.default_args and 'start_date' in dag.default_args:
        min_dag_start_date = dag.default_args["start_date"]
        if min_dag_start_date and execution_date < min_dag_start_date:
            raise ValueError(
                "The execution_date [{0}] should be >= start_date [{1}] from DAG's default_args"
                .format(execution_date.isoformat(), min_dag_start_date.isoformat()))

    run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
    dag_run = dag_run.find(dag_id=dag_id, run_id=run_id)

    if dag_run:
        # dag_run is a list here, so report the requested run_id rather
        # than dereferencing an attribute the list does not have
        raise DagRunAlreadyExists(
            f"Run id {run_id} already exists for dag id {dag_id}")

    run_conf = None
    if conf:
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    triggers = []
    dags_to_trigger = [dag] + dag.subdags
    for _dag in dags_to_trigger:
        trigger = _dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            conf=run_conf,
            external_trigger=True,
        )
        triggers.append(trigger)
    return triggers
def _trigger_dag(
    dag_id: str,
    dag_bag: DagBag,
    run_id: Optional[str] = None,
    conf: Optional[Union[dict, str]] = None,
    execution_date: Optional[datetime] = None,
    replace_microseconds: bool = True,
) -> List[Optional[DagRun]]:
    """Triggers DAG run.

    :param dag_id: DAG ID
    :param dag_bag: DAG Bag model
    :param run_id: ID of the dag_run
    :param conf: configuration
    :param execution_date: date of execution
    :param replace_microseconds: whether microseconds should be zeroed
    :return: list of triggered dags
    """
    dag = dag_bag.get_dag(dag_id)  # prefetch dag if it is stored serialized

    if dag is None or dag_id not in dag_bag.dags:
        raise DagNotFound(f"Dag id {dag_id} not found")

    execution_date = execution_date if execution_date else timezone.utcnow()

    if not timezone.is_localized(execution_date):
        raise ValueError("The execution_date should be localized")

    if replace_microseconds:
        execution_date = execution_date.replace(microsecond=0)

    if dag.default_args and 'start_date' in dag.default_args:
        min_dag_start_date = dag.default_args["start_date"]
        if min_dag_start_date and execution_date < min_dag_start_date:
            raise ValueError(
                f"The execution_date [{execution_date.isoformat()}] should be >= start_date "
                f"[{min_dag_start_date.isoformat()}] from DAG's default_args"
            )
    logical_date = timezone.coerce_datetime(execution_date)

    data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
    run_id = run_id or dag.timetable.generate_run_id(
        run_type=DagRunType.MANUAL, logical_date=logical_date, data_interval=data_interval
    )
    dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)

    if dag_run:
        raise DagRunAlreadyExists(
            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
        )

    run_conf = None
    if conf:
        run_conf = conf if isinstance(conf, dict) else json.loads(conf)

    dag_runs = []
    dags_to_run = [dag] + dag.subdags
    for _dag in dags_to_run:
        dag_run = _dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=DagRunState.QUEUED,
            conf=run_conf,
            external_trigger=True,
            dag_hash=dag_bag.dags_hash.get(dag_id),
            data_interval=data_interval,
        )
        dag_runs.append(dag_run)

    return dag_runs
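# Hypothetical usage sketch for the newer _trigger_dag above; the DAG id and
# conf are illustrative and assume an initialized Airflow 2.x environment
# with serialized DAGs in the metadata DB.
dag_bag = DagBag(read_dags_from_db=True)
runs = _trigger_dag(dag_id='example_bash_operator', dag_bag=dag_bag,
                    conf={'param': 'value'})
print(runs[0].run_id, runs[0].state)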
def run_task_instance(self):
    """Run some tasks while the other tasks do not run.

    According to dag_id and run_id, get the dag_run from the session.
    Obtain the task instances that need to be run according to tasks,
    set the status of these task instances to None, and set the status
    of the other task instances that do not need to run to SUCCESS.

    args:
        dag_id: dag id
        run_id: the run id of the dag run
        tasks: the task ids of the task instances of the dag,
            multiple task ids separated by ','
        conf: define dynamic configuration in the dag
    """
    logging.info("Executing custom 'run_task_instance' function")

    dagbag = self.get_dagbag()

    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')
    tasks = self.get_argument(request, 'tasks')
    conf = self.get_argument(request, 'conf')

    run_conf = None
    if conf:
        try:
            run_conf = json.loads(conf)
        except ValueError:
            return ApiResponse.error('Failed', 'Invalid JSON configuration')

    dr = DagRun.find(dag_id=dag_id, run_id=run_id)
    if dr:
        return ApiResponse.not_found('run_id {} already exists'.format(run_id))

    logging.info('tasks: ' + str(tasks))
    task_list = tasks.split(',')

    session = settings.Session()

    if dag_id not in dagbag.dags:
        return ApiResponse.not_found("Dag id {} not found".format(dag_id))

    dag = dagbag.get_dag(dag_id)
    logging.info('dag: ' + str(dag))

    for task_id in task_list:
        try:
            task = dag.get_task(task_id)
        except TaskNotFound:
            return ApiResponse.not_found("dag task of {} is not found".format(str(task_id)))
        logging.info('task:' + str(task))

    execution_date = timezone.utcnow()
    dag_run = dag.create_dagrun(
        run_id=run_id,
        execution_date=execution_date,
        state=State.RUNNING,
        conf=run_conf,
        external_trigger=True
    )

    tis = dag_run.get_task_instances()
    for ti in tis:
        if ti.task_id in task_list:
            ti.state = None
        else:
            ti.state = State.SUCCESS
        session.merge(ti)

    session.commit()
    session.close()

    return ApiResponse.success({
        "execution_date": execution_date.strftime("%Y-%m-%dT%H:%M:%S.%f%z")
    })
def get_most_recent_dag_run(dag_id):
    dag_runs = DagRun.find(dag_id=dag_id)
    dag_runs.sort(key=lambda x: x.execution_date, reverse=True)
    print(dag_runs)
    return dag_runs[0] if dag_runs else None
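# Hypothetical usage sketch for get_most_recent_dag_run above; the DAG id is
# illustrative and assumes at least one run exists in the metadata DB.
latest = get_most_recent_dag_run('example_bash_operator')
if latest:
    print(latest.run_id, latest.state)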
def _get_dagrun(self, execution_date):
    dag_runs = DagRun.find(
        dag_id=self.subdag.dag_id,
        execution_date=execution_date,
    )
    return dag_runs[0] if dag_runs else None
def setup_method(self):
    self.dag_id = 'test-dag-id'
    self.task_id = 'test-task-id'
    self.execution_date = '2005-04-02T00:00:00+00:00'
    self.execution_date_parsed = parse_execution_date(self.execution_date)
    self.dag_run_id = DR.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
def execute(self, context: Context):
    if isinstance(self.execution_date, datetime.datetime):
        parsed_execution_date = self.execution_date
    elif isinstance(self.execution_date, str):
        parsed_execution_date = timezone.parse(self.execution_date)
    else:
        parsed_execution_date = timezone.utcnow()

    if self.trigger_run_id:
        run_id = self.trigger_run_id
    else:
        run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_execution_date)
    try:
        dag_run = trigger_dag(
            dag_id=self.trigger_dag_id,
            run_id=run_id,
            conf=self.conf,
            execution_date=parsed_execution_date,
            replace_microseconds=False,
        )
    except DagRunAlreadyExists as e:
        if self.reset_dag_run:
            self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_execution_date)

            # Get target dag object and call clear()
            dag_model = DagModel.get_current(self.trigger_dag_id)
            if dag_model is None:
                raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")

            dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
            dag = dag_bag.get_dag(self.trigger_dag_id)
            dag.clear(start_date=parsed_execution_date, end_date=parsed_execution_date)
            dag_run = DagRun.find(dag_id=dag.dag_id, run_id=run_id)[0]
        else:
            raise e

    if dag_run is None:
        raise RuntimeError("The dag_run should be set here!")

    # Store the execution date from the dag run (either created or found above) to
    # be used when creating the extra link on the webserver.
    ti = context['task_instance']
    ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat())
    ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)

    if self.wait_for_completion:
        # wait for dag to complete
        while True:
            self.log.info(
                'Waiting for %s on %s to become allowed state %s ...',
                self.trigger_dag_id,
                dag_run.execution_date,
                self.allowed_states,
            )
            time.sleep(self.poke_interval)

            dag_run.refresh_from_db()
            state = dag_run.state
            if state in self.failed_states:
                raise AirflowException(f"{self.trigger_dag_id} failed with failed states {state}")
            if state in self.allowed_states:
                self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state)
                return
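# Hypothetical usage sketch for the execute() method above, which matches the
# shape of Airflow's TriggerDagRunOperator; the DAG ids and settings here are
# illustrative.
trigger = TriggerDagRunOperator(
    task_id='trigger_downstream',
    trigger_dag_id='downstream_dag',
    conf={'source': 'upstream_dag'},
    reset_dag_run=True,
    wait_for_completion=True,
    poke_interval=30,
)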
def test_backfill_fill_blanks(self):
    dag = DAG(
        'test_backfill_fill_blanks',
        start_date=DEFAULT_DATE,
        default_args={'owner': 'owner1'},
    )

    with dag:
        op1 = DummyOperator(task_id='op1')
        op2 = DummyOperator(task_id='op2')
        op3 = DummyOperator(task_id='op3')
        op4 = DummyOperator(task_id='op4')
        op5 = DummyOperator(task_id='op5')
        op6 = DummyOperator(task_id='op6')

    dag.clear()
    dr = dag.create_dagrun(run_id='test',
                           state=State.RUNNING,
                           execution_date=DEFAULT_DATE,
                           start_date=DEFAULT_DATE)

    executor = TestExecutor()

    session = settings.Session()

    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id == op1.task_id:
            ti.state = State.UP_FOR_RETRY
            ti.end_date = DEFAULT_DATE
        elif ti.task_id == op2.task_id:
            ti.state = State.FAILED
        elif ti.task_id == op3.task_id:
            ti.state = State.SKIPPED
        elif ti.task_id == op4.task_id:
            ti.state = State.SCHEDULED
        elif ti.task_id == op5.task_id:
            ti.state = State.UPSTREAM_FAILED
        # op6 = None
        session.merge(ti)
    session.commit()
    session.close()

    job = BackfillJob(dag=dag,
                      start_date=DEFAULT_DATE,
                      end_date=DEFAULT_DATE,
                      executor=executor)
    self.assertRaisesRegex(
        AirflowException,
        'Some task instances failed',
        job.run)

    self.assertRaises(sqlalchemy.orm.exc.NoResultFound, dr.refresh_from_db)
    # the run_id should have changed, so a refresh won't work
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE)
    dr = drs[0]

    self.assertEqual(dr.state, State.FAILED)

    tis = dr.get_task_instances()
    for ti in tis:
        if ti.task_id in (op1.task_id, op4.task_id, op6.task_id):
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == op2.task_id:
            self.assertEqual(ti.state, State.FAILED)
        elif ti.task_id == op3.task_id:
            self.assertEqual(ti.state, State.SKIPPED)
        elif ti.task_id == op5.task_id:
            self.assertEqual(ti.state, State.UPSTREAM_FAILED)
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from execution_date) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (though it will for subdag dag runs if needed).

    :param task: the task from which to work. task.dag needs to be set
    :param execution_date: the execution date from which to start looking
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id,
        including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :return: list of tasks that have been created and updated
    """
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but is not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date
    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (siblings = downstream, parents = upstream) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed
    # set the confirmed execution dates as they might be different
    # from what was provided
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    # go through subdagoperators and create dag runs. We will only work
    # within the scope of the subdag. We won't propagate to the parent dag,
    # but we will propagate from parent to subdag.
    session = Session()
    dags = [dag]
    sub_dag_ids = []
    while len(dags) > 0:
        current_dag = dags.pop()
        for task_id in task_ids:
            if not current_dag.has_task(task_id):
                continue

            current_task = current_dag.get_task(task_id)
            if isinstance(current_task, SubDagOperator):
                # this works as a kind of integrity check
                # it creates missing dag runs for subdagoperators,
                # maybe this should be moved to dagrun.verify_integrity
                drs = _create_dagruns(current_task.subdag,
                                      execution_dates=confirmed_dates,
                                      state=State.RUNNING,
                                      run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                for dr in drs:
                    dr.dag = current_task.subdag
                    dr.verify_integrity()
                    if commit:
                        dr.state = state
                        session.merge(dr)

                dags.append(current_task.subdag)
                sub_dag_ids.append(current_task.subdag.dag_id)

    # now look for the task instances that are affected
    TI = TaskInstance

    # get all tasks of the main dag that will be affected by a state change
    qry_dag = session.query(TI).filter(
        TI.dag_id == dag.dag_id,
        TI.execution_date.in_(confirmed_dates),
        TI.task_id.in_(task_ids)).filter(
        or_(TI.state.is_(None), TI.state != state)
    )

    # get *all* tasks of the sub dags
    if len(sub_dag_ids) > 0:
        qry_sub_dag = session.query(TI).filter(
            TI.dag_id.in_(sub_dag_ids),
            TI.execution_date.in_(confirmed_dates)).filter(
            or_(TI.state.is_(None), TI.state != state)
        )

    if commit:
        tis_altered = qry_dag.with_for_update().all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.with_for_update().all()
        for ti in tis_altered:
            ti.state = state
        session.commit()
    else:
        tis_altered = qry_dag.all()
        if len(sub_dag_ids) > 0:
            tis_altered += qry_sub_dag.all()

    session.expunge_all()
    session.close()

    return tis_altered
def restart_failed_task(self):
    """Restart the failed tasks in the specified dag run.

    According to dag_id and run_id, get the dag_run from the session, query
    the task_instances whose status is FAILED in the dag_run, restart them
    and clear the status of each task_instance's downstream tasks.

    args:
        dag_id: dag id
        run_id: the run id of the dag run
    """
    logging.info("Executing custom 'restart_failed_task' function")

    dagbag = self.get_dagbag()

    dag_id = self.get_argument(request, 'dag_id')
    run_id = self.get_argument(request, 'run_id')

    session = settings.Session()
    query = session.query(DagRun)
    dag_run = query.filter(
        DagRun.dag_id == dag_id,
        DagRun.run_id == run_id
    ).first()

    if dag_run is None:
        return ApiResponse.not_found("dag run is not found")

    if dag_id not in dagbag.dags:
        return ApiResponse.bad_request("Dag id {} not found".format(dag_id))

    dag = dagbag.get_dag(dag_id)
    if dag is None:
        return ApiResponse.not_found("dag is not found")

    tis = DagRun.get_task_instances(dag_run, State.FAILED)
    logging.info('task_instances: ' + str(tis))

    failed_task_count = len(tis)
    if failed_task_count > 0:
        for ti in tis:
            dag = DAG.sub_dag(
                self=dag,
                task_regex=r"^{0}$".format(ti.task_id),
                include_downstream=True,
                include_upstream=False)

            count = DAG.clear(
                self=dag,
                start_date=dag_run.execution_date,
                end_date=dag_run.execution_date,
            )
            logging.info('count:' + str(count))
    else:
        return ApiResponse.not_found("dag run doesn't have failed tasks")

    session.close()

    return ApiResponse.success({
        'failed_task_count': failed_task_count,
        'clear_task_count': count
    })