def get_link(self, operator, dttm):
    task_instance = TaskInstance(task=operator, execution_date=dttm)
    gcp_metadata_dict = task_instance.xcom_pull(task_ids=operator.task_id, key="gcp_metadata")
    if not gcp_metadata_dict:
        return ''
    job_id = gcp_metadata_dict['job_id']
    project_id = gcp_metadata_dict['project_id']
    console_link = f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
    return console_link
def test_mark_success_on_success_callback(self):
    """
    Test that ensures that when a task is marked success in the UI
    on_success_callback gets executed
    """
    data = {'called': False}

    def success_callback(context):
        self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
        data['called'] = True

    dag = DAG(dag_id='test_mark_success',
              start_date=DEFAULT_DATE,
              default_args={'owner': 'owner1'})
    task = DummyOperator(task_id='test_state_succeeded1',
                         dag=dag,
                         on_success_callback=success_callback)

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
    job1.task_runner = StandardTaskRunner(job1)
    process = multiprocessing.Process(target=job1.run)
    process.start()
    ti.refresh_from_db()
    for _ in range(0, 50):
        if ti.state == State.RUNNING:
            break
        time.sleep(0.1)
        ti.refresh_from_db()
    self.assertEqual(State.RUNNING, ti.state)
    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    job1.heartbeat_callback(session=None)
    self.assertTrue(data['called'])
    process.join(timeout=10)
    self.assertFalse(process.is_alive())
def test_mark_failure_on_failure_callback(self):
    """
    Test that ensures that mark_failure in the UI fails
    the task, and executes on_failure_callback
    """
    data = {'called': False}

    def check_failure(context):
        self.assertEqual(context['dag_run'].dag_id, 'test_mark_failure')
        data['called'] = True

    def task_function(ti):
        print("python_callable run in pid %s", os.getpid())
        with create_session() as session:
            self.assertEqual(State.RUNNING, ti.state)
            ti.log.info("Marking TI as failed 'externally'")
            ti.state = State.FAILED
            session.merge(ti)
            session.commit()

        time.sleep(60)
        # This should not happen -- the state change should be noticed and the task should get killed
        data['reached_end_of_sleep'] = True

    with DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE) as dag:
        task = PythonOperator(task_id='test_state_succeeded1',
                              python_callable=task_function,
                              on_failure_callback=check_failure)

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(run_id="test",
                      state=State.RUNNING,
                      execution_date=DEFAULT_DATE,
                      start_date=DEFAULT_DATE,
                      session=session)
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti,
                        ignore_ti_state=True,
                        executor=SequentialExecutor())
    with timeout(30):
        # This should be _much_ shorter to run.
        # If you change this limit, make the timeout in the callable above bigger
        job1.run()

    ti.refresh_from_db()
    self.assertEqual(ti.state, State.FAILED)
    self.assertTrue(data['called'])
    self.assertNotIn(
        'reached_end_of_sleep', data,
        'Task should not have been allowed to run to completion')
def test_init_with_template_cluster_label(self):
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
    task = QuboleOperator(task_id=TASK_ID,
                          dag=dag,
                          cluster_label='{{ params.cluster_label }}',
                          params={'cluster_label': 'default'})
    ti = TaskInstance(task, DEFAULT_DATE)
    ti.render_templates()
    self.assertEqual(task.cluster_label, 'default')
def test_get_dated_main_runner_handles_zero_shift():
    dag = DAG(dag_id='test_dag',
              start_date=datetime.strptime('2019-01-01', '%Y-%m-%d'))
    execution_date = datetime.strptime('2019-01-01', '%Y-%m-%d').replace(tzinfo=timezone.utc)
    main_func = PickleMock()
    runner = op_util.get_dated_main_runner_operator(dag, main_func, timedelta(minutes=1))
    ti = TaskInstance(runner, execution_date)
    ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
    main_func.assert_called_with('2019-01-01')
def test_try_adopt_task_instances(self):
    exec_date = timezone.utcnow() - timedelta(minutes=2)
    start_date = timezone.utcnow() - timedelta(days=2)
    queued_dttm = timezone.utcnow() - timedelta(minutes=1)
    try_number = 1

    with DAG("test_try_adopt_task_instances_none") as dag:
        task_1 = BaseOperator(task_id="task_1", start_date=start_date)
        task_2 = BaseOperator(task_id="task_2", start_date=start_date)

    ti1 = TaskInstance(task=task_1, execution_date=exec_date)
    ti1.external_executor_id = '231'
    ti1.queued_dttm = queued_dttm
    ti2 = TaskInstance(task=task_2, execution_date=exec_date)
    ti2.external_executor_id = '232'
    ti2.queued_dttm = queued_dttm
    tis = [ti1, ti2]

    executor = celery_executor.CeleryExecutor()
    self.assertEqual(executor.running, set())
    self.assertEqual(executor.adopted_task_timeouts, {})
    self.assertEqual(executor.tasks, {})

    not_adopted_tis = executor.try_adopt_task_instances(tis)

    key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
    key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
    self.assertEqual(executor.running, {key_1, key_2})
    self.assertEqual(
        dict(executor.adopted_task_timeouts),
        {
            key_1: queued_dttm + executor.task_adoption_timeout,
            key_2: queued_dttm + executor.task_adoption_timeout
        })
    self.assertEqual(executor.tasks, {
        key_1: AsyncResult("231"),
        key_2: AsyncResult("232")
    })
    self.assertEqual(not_adopted_tis, [])
def verify_integrity(self, session=None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = []
    for ti in tis:
        task_ids.append(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state is not State.RUNNING and not dag.partial:
                self.log.warning("Failed to get task '{}' for dag '{}'. "
                                 "Marking it as removed.".format(ti, dag))
                Stats.incr(
                    "task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        is_task_in_dag = task is not None
        should_restore_task = is_task_in_dag and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info("Restoring task '{}' which was previously "
                          "removed from DAG '{}'".format(ti, dag))
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE

    # check for missing tasks
    for task in six.itervalues(dag.task_dict):
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(
                "task_instance_created-{}".format(task.__class__.__name__),
                1, 1)
            # add TaskState to db
            ti = TaskInstance(task, self.execution_date)
            ts = TaskState(ti)
            if task.event_met_handler() is not None:
                ts.event_handler = task.event_met_handler()
            session.add(ti)
            session.add(ts)

    session.commit()
def test_try_adopt_task_instances_none(self):
    date = datetime.utcnow()
    start_date = datetime.utcnow() - timedelta(days=2)

    with DAG("test_try_adopt_task_instances_none"):
        task_1 = BaseOperator(task_id="task_1", start_date=start_date)

    key1 = TaskInstance(task=task_1, execution_date=date)
    tis = [key1]
    executor = celery_executor.CeleryExecutor()
    self.assertEqual(executor.try_adopt_task_instances(tis), tis)
def test_localtaskjob_maintain_heart_rate(self):
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_localtaskjob_double_trigger')
    task = dag.get_task('test_localtaskjob_double_trigger_task')

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )

    ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti_run.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())

    # this should make sure we only heartbeat once and exit at the second
    # loop in _execute()
    return_codes = [None, 0]

    def multi_return_code():
        return return_codes.pop(0)

    time_start = time.time()
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
    with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
        with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
            mock_ret_code.side_effect = multi_return_code
            job1.run()
            self.assertEqual(mock_start.call_count, 1)
            self.assertEqual(mock_ret_code.call_count, 2)
    time_end = time.time()

    self.assertEqual(self.mock_base_job_sleep.call_count, 1)
    self.assertEqual(job1.state, State.SUCCESS)

    # Since the sleep call is patched (and we verified above it is called
    # only once), the job should not spend time sleeping to keep up with
    # the heart rate anywhere else, so the run should finish well within
    # one heartrate interval.
    self.assertLess(time_end - time_start, job1.heartrate)
    session.close()
def test_should_continue_with_cp(load_dag):
    dag_bag = load_dag('bq_to_wrench')
    dag = dag_bag.get_dag('bq_to_wrench')
    table = 'staging.users'
    task = dag.get_task(f'continue_if_data_{table}')
    assert isinstance(task, BranchPythonOperator)
    ti = TaskInstance(task=task, execution_date=datetime.now())
    XCom.set(key=table,
             value={'has_data': True},
             task_id=task.task_id,
             dag_id=dag.dag_id,
             execution_date=ti.execution_date)
    task.execute(ti.get_template_context())
def test_sets_initial_checkpoint(load_dag, env, bigquery_helper):
    # Remove all checkpoints for table
    table = 'staging.users'
    bigquery_helper.query(
        f"DELETE FROM `{env['project']}.system.checkpoint` WHERE table = '{table}'"
    )

    # Execute get checkpoint task. I expect it to create an initial checkpoint.
    dag_bag = load_dag('bq_to_wrench')
    dag = dag_bag.get_dag('bq_to_wrench')
    task = dag.get_task(f'get_checkpoint_{table}')
    assert isinstance(task, GetCheckpointOperator)
    ti = TaskInstance(task=task, execution_date=datetime.now())
    task.execute(ti.get_template_context())
def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
    unique_prefix = str(uuid.uuid4())
    dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
    task = DummyOperator(task_id='test_state_succeeded1', dag=dag)

    dag.clear()
    dag.create_dagrun(run_id=unique_prefix, state=State.NONE)

    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

    # `side_effect` (singular) is the Mock attribute that makes the runner
    # return each code in turn; `side_effects` would be silently ignored.
    mock_get_task_runner.return_value.return_code.side_effect = return_codes

    job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
    with assert_queries_count(13):
        job.run()
def get_link(self, operator, dttm):
    """
    Get link to qubole command result page.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    ti = TaskInstance(task=operator, execution_date=dttm)
    conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
    if conn and conn.host:
        host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
    else:
        host = 'https://api.qubole.com/v2/analyze?command_id='

    qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
    url = host + str(qds_command_id) if qds_command_id else ''
    return url
def run(self, start_date=None, end_date=None, ignore_first_depends_on_past=False,
        ignore_ti_state=False, mark_success=False):
    """
    Run a set of task instances for a date range.
    """
    start_date = start_date or self.start_date
    end_date = end_date or self.end_date or timezone.utcnow()

    for dt in self.dag.date_range(start_date, end_date=end_date):
        TaskInstance(self, dt).run(
            mark_success=mark_success,
            ignore_depends_on_past=(dt == start_date and ignore_first_depends_on_past),
            ignore_ti_state=ignore_ti_state)
def test_heartbeat_failed_fast(self):
    """
    Test that task heartbeat will sleep when it fails fast
    """
    self.mock_base_job_sleep.side_effect = time.sleep

    with create_session() as session:
        dagbag = DagBag(
            dag_folder=TEST_DAG_FOLDER,
            include_examples=False,
        )
        dag_id = 'test_heartbeat_failed_fast'
        task_id = 'test_heartbeat_failed_fast_op'
        dag = dagbag.get_dag(dag_id)
        task = dag.get_task(task_id)

        dag.create_dagrun(
            run_id="test_heartbeat_failed_fast_run",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.commit()

        job = LocalTaskJob(task_instance=ti, executor=MockExecutor(do_update=False))
        job.heartrate = 2
        heartbeat_records = []
        job.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat)
        job._execute()
        self.assertGreater(len(heartbeat_records), 2)
        for i in range(1, len(heartbeat_records)):
            time1 = heartbeat_records[i - 1]
            time2 = heartbeat_records[i]
            # Assert that consecutive heartbeats are roughly one heartrate apart
            delta = (time2 - time1).total_seconds()
            self.assertAlmostEqual(delta, job.heartrate, delta=0.05)
def run(self,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        ignore_first_depends_on_past: bool = False,
        ignore_ti_state: bool = False,
        mark_success: bool = False) -> None:
    """
    Run a set of task instances for a date range.
    """
    start_date = start_date or self.start_date
    end_date = end_date or self.end_date or timezone.utcnow()

    for execution_date in self.dag.date_range(start_date, end_date=end_date):
        TaskInstance(self, execution_date).run(
            mark_success=mark_success,
            ignore_depends_on_past=(execution_date == start_date and ignore_first_depends_on_past),
            ignore_ti_state=ignore_ti_state)
def test_PyIdempaOp_idempotent(self):
    self.dag = DAG(
        TEST_DAG_ID,
        schedule_interval="@daily",
        default_args={"start_date": datetime.now()},
    )

    with TemporaryDirectory() as tempdir:
        output_path = f"{tempdir}/test_file.txt"
        self.assertFalse(os.path.exists(output_path))  # ensure doesn't already exist

        self.op = PythonIdempatomicFileOperator(
            dag=self.dag,
            task_id="test",
            output_pattern=output_path,
            python_callable=self.f,
        )
        self.ti = TaskInstance(task=self.op, execution_date=datetime.now())

        result = self.op.execute(self.ti.get_template_context())
        self.assertEqual(result, output_path)
        self.assertFalse(self.op.previously_completed)
        self.assertTrue(os.path.exists(output_path))

        with open(output_path, "r") as fout:
            self.assertEqual(fout.read(), "test")

        # now run task again
        result = self.op.execute(self.ti.get_template_context())
        self.assertEqual(result, output_path)  # result will still give path
        self.assertTrue(self.op.previously_completed)

        # if function had run again, it would now be 'testtest'
        with open(output_path, "r") as fout:
            self.assertEqual(fout.read(), "test")

        # run function again to ensure 'testtest' is written to file upon second call
        self.f(output_path)
        with open(output_path, "r") as fout:
            self.assertEqual(fout.read(), "testtest")
def test_mark_success_no_kill(self):
    """
    Test that ensures that mark_success in the UI doesn't cause
    the task to fail, and that the task exits
    """
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_mark_success')
    task = dag.get_task('task1')

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
    process = multiprocessing.Process(target=job1.run)
    process.start()
    ti.refresh_from_db()
    for _ in range(0, 50):
        if ti.state == State.RUNNING:
            break
        time.sleep(0.1)
        ti.refresh_from_db()
    self.assertEqual(State.RUNNING, ti.state)
    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    process.join(timeout=10)
    self.assertFalse(process.is_alive())
    ti.refresh_from_db()
    self.assertEqual(State.SUCCESS, ti.state)
def test_get_redirect_url(self):
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)

    with dag:
        task = QuboleOperator(task_id=TASK_ID,
                              qubole_conn_id=TEST_CONN,
                              command_type='shellcmd',
                              parameters="param1 param2",
                              dag=dag)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.xcom_push('qbol_cmd_id', 12345)

    # check for positive case
    url = task.get_extra_links(DEFAULT_DATE, 'Go to QDS')
    self.assertEqual(url, 'http://localhost/v2/analyze?command_id=12345')

    # check for negative case
    url2 = task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS')
    self.assertEqual(url2, '')
def skip(self, dag_run, execution_date, tasks, session=None):
    """
    Sets task instances to skipped from the same dag run.

    :param dag_run: the DagRun for which to set the tasks to skipped
    :param execution_date: execution_date
    :param tasks: tasks to skip (not task_ids)
    :param session: db session to use
    """
    if not tasks:
        return

    task_ids = [d.task_id for d in tasks]
    now = timezone.utcnow()

    if dag_run:
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.execution_date == dag_run.execution_date,
            TaskInstance.task_id.in_(task_ids)
        ).update({TaskInstance.state: State.SKIPPED,
                  TaskInstance.start_date: now,
                  TaskInstance.end_date: now},
                 synchronize_session=False)
        session.commit()
    else:
        if execution_date is None:
            raise ValueError("Execution date is None and no dag run")

        self.log.warning("No DAG RUN present; this should not happen")
        # this is defensive against dag runs that are not complete
        for task in tasks:
            ti = TaskInstance(task, execution_date=execution_date)
            ti.state = State.SKIPPED
            ti.start_date = now
            ti.end_date = now
            session.merge(ti)

        session.commit()
def test_localtaskjob_double_trigger(self):
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_localtaskjob_double_trigger')
    task = dag.get_task('test_localtaskjob_double_trigger_task')

    session = settings.Session()

    dag.clear()
    dr = dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = dr.get_task_instance(task_id=task.task_id, session=session)
    ti.state = State.RUNNING
    ti.hostname = get_hostname()
    ti.pid = 1
    session.merge(ti)
    session.commit()

    ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti_run.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
    with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
        job1.run()
        mock_method.assert_not_called()

    ti = dr.get_task_instance(task_id=task.task_id, session=session)
    self.assertEqual(ti.pid, 1)
    self.assertEqual(ti.state, State.RUNNING)

    session.close()
def test_error_sending_task(self):
    def fake_execute_command():
        pass

    with _prepare_app(execute=fake_execute_command):
        # fake_execute_command takes no arguments while execute_command takes 1,
        # which will cause TypeError when calling task.apply_async()
        executor = celery_executor.CeleryExecutor()
        task = BashOperator(
            task_id="test",
            bash_command="true",
            dag=DAG(dag_id='id'),
            start_date=datetime.now()
        )
        when = datetime.now()
        value_tuple = 'command', 1, None, \
            SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now()))
        key = ('fail', 'fake_simple_ti', when, 0)
        executor.queued_tasks[key] = value_tuple
        executor.heartbeat()
    self.assertEqual(0, len(executor.queued_tasks), "Task should no longer be queued")
    self.assertEqual(executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0], State.FAILED)
def _kill_zombies(self, dag, zombies, session):
    """
    copy paste from airflow.models.dagbag.DagBag.kill_zombies
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    for zombie in zombies:
        if zombie.task_id in dag.task_ids:
            task = dag.get_task(zombie.task_id)
            ti = TaskInstance(task, zombie.execution_date)
            # Get properties needed for failure handling from SimpleTaskInstance.
            ti.start_date = zombie.start_date
            ti.end_date = zombie.end_date
            ti.try_number = zombie.try_number
            ti.state = zombie.state
            # ti.test_mode = self.UNIT_TEST_MODE
            ti.handle_failure(
                "{} detected as zombie".format(ti),
                ti.test_mode,
                ti.get_template_context(),
            )
            self.log.info("Marked zombie job %s as %s", ti, ti.state)
    session.commit()
def expand_mapped_task(self, run_id: str, *, session: Session) -> Sequence["TaskInstance"]:
    """Create the mapped task instances for mapped task.

    :return: The mapped task instances, in ascending order by map index.
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.settings import task_instance_mutation_hook

    total_length = functools.reduce(
        operator.mul, self._get_map_lengths(run_id, session=session).values())

    state: Optional[TaskInstanceState] = None
    unmapped_ti: Optional[TaskInstance] = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
        )
        .one_or_none()
    )

    ret: List[TaskInstance] = []

    if unmapped_ti:
        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length < 1:
            # If the upstream maps this to a zero-length value, simply mark the
            # unmapped task instance as SKIPPED (if needed).
            self.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
            session.flush()
            return ret
        # Otherwise convert this into the first mapped index, and create
        # TaskInstance for other indexes.
        unmapped_ti.map_index = 0
        state = unmapped_ti.state
        self.log.debug("Updated in place to become %s", unmapped_ti)
        ret.append(unmapped_ti)
        indexes_to_map = range(1, total_length)
    else:
        # Only create "missing" ones.
        current_max_mapping = (
            session.query(func.max(TaskInstance.map_index))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
            )
            .scalar()
        )
        indexes_to_map = range(current_max_mapping + 1, total_length)

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
        ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)  # type: ignore
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.task = self
        ret.append(ti)

    # Set to "REMOVED" any (old) TaskInstances with map indices greater
    # than the current map value
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_length,
    ).update({TaskInstance.state: TaskInstanceState.REMOVED})

    session.flush()
    return ret
def expand_mapped_task(
    self, run_id: str, *, session: Session
) -> Tuple[Sequence["TaskInstance"], int]:
    """Create the mapped task instances for mapped task.

    :return: The newly created mapped TaskInstances (if any) in ascending
        order by map index, and the maximum map_index.
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.settings import task_instance_mutation_hook

    total_length: Optional[int]
    try:
        total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
    except NotFullyPopulated as e:
        self.log.info(
            "Cannot expand %r for run %s; missing upstream values: %s",
            self,
            run_id,
            sorted(e.missing),
        )
        total_length = None

    state: Optional[TaskInstanceState] = None
    unmapped_ti: Optional[TaskInstance] = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
        )
        .one_or_none()
    )

    all_expanded_tis: List[TaskInstance] = []

    if unmapped_ti:
        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length is None:
            # If the map length cannot be calculated (due to unavailable
            # upstream sources), fail the unmapped task.
            unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            indexes_to_map: Iterable[int] = ()
        elif total_length < 1:
            # If the upstream maps this to a zero-length value, simply mark
            # the unmapped task instance as SKIPPED (if needed).
            self.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
            indexes_to_map = ()
        else:
            # Otherwise convert this into the first mapped index, and create
            # TaskInstance for other indexes.
            unmapped_ti.map_index = 0
            self.log.debug("Updated in place to become %s", unmapped_ti)
            all_expanded_tis.append(unmapped_ti)
            indexes_to_map = range(1, total_length)
        state = unmapped_ti.state
    elif not total_length:
        # Nothing to fixup.
        indexes_to_map = ()
    else:
        # Only create "missing" ones.
        current_max_mapping = (
            session.query(func.max(TaskInstance.map_index))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
            )
            .scalar()
        )
        indexes_to_map = range(current_max_mapping + 1, total_length)

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.refresh_from_task(self)  # session.merge() loses task information.
        all_expanded_tis.append(ti)

    # Coerce the None case to 0 -- these two are almost treated identically,
    # except the unmapped ti (if exists) is marked to different states.
    total_expanded_ti_count = total_length or 0

    # Set to "REMOVED" any (old) TaskInstances with map indices greater
    # than the current map value
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_expanded_ti_count,
    ).update({TaskInstance.state: TaskInstanceState.REMOVED})

    session.flush()
    return all_expanded_tis, total_expanded_ti_count - 1
def test_mark_success_on_success_callback(self):
    """
    Test that ensures that when a task is marked success in the UI
    on_success_callback gets executed
    """
    # use shared memory value so we can properly track value change even if
    # it's been updated across processes.
    success_callback_called = Value('i', 0)
    task_terminated_externally = Value('i', 1)
    shared_mem_lock = Lock()

    def success_callback(context):
        with shared_mem_lock:
            success_callback_called.value += 1
        assert context['dag_run'].dag_id == 'test_mark_success'

    dag = DAG(dag_id='test_mark_success',
              start_date=DEFAULT_DATE,
              default_args={'owner': 'owner1'})

    def task_function(ti):  # pylint: disable=unused-argument
        time.sleep(60)
        # This should not happen -- the state change should be noticed and the task should get killed
        with shared_mem_lock:
            task_terminated_externally.value = 0

    task = PythonOperator(
        task_id='test_state_succeeded1',
        python_callable=task_function,
        on_success_callback=success_callback,
        dag=dag,
    )

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
    job1.task_runner = StandardTaskRunner(job1)

    settings.engine.dispose()
    process = multiprocessing.Process(target=job1.run)
    process.start()

    for _ in range(0, 25):
        ti.refresh_from_db()
        if ti.state == State.RUNNING:
            break
        time.sleep(0.2)
    assert ti.state == State.RUNNING
    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    process.join(timeout=10)
    assert success_callback_called.value == 1
    assert task_terminated_externally.value == 1
    assert not process.is_alive()
def get_link(self, operator, dttm):
    ti = TaskInstance(task=operator, execution_date=dttm)
    operator.render_template_fields(ti.get_template_context())
    query = {"dag_id": operator.external_dag_id, "execution_date": dttm.isoformat()}
    return build_airflow_url_with_query(query)
def test_retry_on_error_sending_task(self):
    """Test that Airflow retries publishing tasks to Celery Broker at least 3 times"""

    with _prepare_app(), self.assertLogs(celery_executor.log) as cm, mock.patch.object(
        # Mock `with timeout()` to _instantly_ fail.
        celery_executor.timeout,
        "__enter__",
        side_effect=AirflowTaskTimeout,
    ):
        executor = celery_executor.CeleryExecutor()
        assert executor.task_publish_retries == {}
        assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"

        task = BashOperator(
            task_id="test",
            bash_command="true",
            dag=DAG(dag_id='id'),
            start_date=datetime.now()
        )
        when = datetime.now()
        value_tuple = (
            'command',
            1,
            None,
            SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())),
        )
        key = ('fail', 'fake_simple_ti', when, 0)
        executor.queued_tasks[key] = value_tuple

        # Test that when heartbeat is called again, task is published again to Celery Queue
        executor.heartbeat()
        assert dict(executor.task_publish_retries) == {key: 2}
        assert 1 == len(executor.queued_tasks), "Task should remain in queue"
        assert executor.event_buffer == {}
        assert (
            "INFO:airflow.executors.celery_executor.CeleryExecutor:"
            f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in cm.output
        )

        executor.heartbeat()
        assert dict(executor.task_publish_retries) == {key: 3}
        assert 1 == len(executor.queued_tasks), "Task should remain in queue"
        assert executor.event_buffer == {}
        assert (
            "INFO:airflow.executors.celery_executor.CeleryExecutor:"
            f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in cm.output
        )

        executor.heartbeat()
        assert dict(executor.task_publish_retries) == {key: 4}
        assert 1 == len(executor.queued_tasks), "Task should remain in queue"
        assert executor.event_buffer == {}
        assert (
            "INFO:airflow.executors.celery_executor.CeleryExecutor:"
            f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in cm.output
        )

        executor.heartbeat()
        assert dict(executor.task_publish_retries) == {}
        assert 0 == len(executor.queued_tasks), "Task should no longer be in queue"
        assert executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0] == State.FAILED
def get_link(self, operator, dttm):
    ti = TaskInstance(task=operator, execution_date=dttm)
    job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id')
    return 'https://console.cloud.google.com/bigquery?j={job_id}'.format(
        job_id=job_id) if job_id else ''
def __create_task_instance(self, task):
    ti = TaskInstance(task=task, execution_date=datetime.now())
    return ti