def test_init_with_template_cluster_label(self):
    """A templated ``cluster_label`` is rendered from the operator's ``params``."""
    template_dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
    operator = QuboleOperator(
        task_id=TASK_ID,
        dag=template_dag,
        cluster_label='{{ params.cluster_label }}',
        params={'cluster_label': 'default'},
    )

    # Rendering through a TaskInstance updates the attribute on the operator itself.
    TaskInstance(operator, DEFAULT_DATE).render_templates()

    self.assertEqual(operator.cluster_label, 'default')
def test_get_dated_main_runner_handles_day_shift():
    """With ``day_shift=1`` the main function is called with the day before the execution date."""
    first_day = datetime.strptime('2019-01-01', '%Y-%m-%d')
    dag = DAG(dag_id='test_dag', start_date=first_day)
    execution_date = first_day.replace(tzinfo=timezone.utc)

    main_callable = PickleMock()
    dated_runner = op_util.get_dated_main_runner_operator(
        dag, main_callable, timedelta(minutes=1), day_shift=1
    )

    task_instance = TaskInstance(dated_runner, execution_date)
    task_instance.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)

    main_callable.assert_called_with('2018-12-31')
def __init__(self, ti: TaskInstance, render_templates=True):
    """Capture the (optionally rendered) template field values of ``ti``'s task.

    :param ti: task instance whose identity and template fields are recorded
    :param render_templates: when true, render ``ti``'s templates before reading
        the task's template fields; otherwise the current values are captured.
    """
    self.dag_id, self.task_id, self.task, self.execution_date = (
        ti.dag_id,
        ti.task_id,
        ti.task,
        ti.execution_date,
    )
    self.ti = ti

    if render_templates:
        ti.render_templates()

    # The pod YAML is only recorded when this environment variable is set.
    if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", None):
        self.k8s_pod_yaml = ti.render_k8s_pod_yaml()

    rendered = {}
    for field_name in self.task.template_fields:
        rendered[field_name] = serialize_template_field(getattr(self.task, field_name))
    self.rendered_fields = rendered
def test_localtaskjob_maintain_heart_rate(self):
    """LocalTaskJob stops once the runner reports a return code, without waiting a full heartrate.

    The runner's ``start`` is stubbed out and ``return_code`` yields ``None`` then
    ``0``, so the job's loop runs twice. The test checks that the (patched) sleep is
    called exactly once and that the whole run finishes faster than ``job1.heartrate``.
    """
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_localtaskjob_double_trigger')
    task = dag.get_task('test_localtaskjob_double_trigger_task')

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )

    ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti_run.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())

    # this should make sure we only heartbeat once and exit at the second
    # loop in _execute()
    return_codes = [None, 0]

    def multi_return_code():
        # Each call consumes the next scripted return code: None ("still running"), then 0.
        return return_codes.pop(0)

    time_start = time.time()
    with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
        with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
            mock_ret_code.side_effect = multi_return_code
            job1.run()
            assert mock_start.call_count == 1
            assert mock_ret_code.call_count == 2
    time_end = time.time()

    # NOTE(review): self.mock_base_job_sleep is presumably a patched sleep set up in
    # setUp (not visible here) -- confirm.
    assert self.mock_base_job_sleep.call_count == 1
    assert job1.state == State.SUCCESS

    # Since the sleep call is patched, the job should not sleep anywhere else to
    # keep up with the heart rate.
    #
    # We already make sure patched sleep call is only called once
    assert time_end - time_start < job1.heartrate
    session.close()
def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, Iterable[str]]):
    """
    Implement the skipping side of a branching operator.

    Given one task id, or an iterable of task ids, to follow, every task immediately
    downstream of this operator that is not followed is set to skipped. Tasks that
    are downstream of a followed task count as followed too, so a task reachable both
    directly and through a followed branch is not skipped.

    The followed downstream task ids are pushed to XCom so that NotPreviouslySkippedDep
    knows skipped tasks or newly added tasks should be skipped when they are cleared.

    :param ti: task instance of the branching operator
    :param branch_task_ids: task id(s) to follow; ``None`` follows nothing
    """
    self.log.info("Following branch %s", branch_task_ids)

    if branch_task_ids is None:
        followed_ids = set()
    elif isinstance(branch_task_ids, str):
        followed_ids = {branch_task_ids}
    else:
        followed_ids = set(branch_task_ids)

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag
    assert dag  # For Mypy.

    # At runtime, the downstream list will only be operators
    downstream_tasks = cast("List[BaseOperator]", task.downstream_list)
    if not downstream_tasks:
        return

    # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
    # we intuitively expect both "task1" and "join" to execute even though strictly speaking,
    # "join" is also immediately downstream of "branch" and should have been skipped. Therefore,
    # anything downstream of a followed task is treated as followed as well.
    #
    # branch -----> join
    #   \            ^
    #     v        /
    #       task1
    #
    kept_ids = set(followed_ids)
    for followed_id in followed_ids:
        kept_ids.update(dag.get_task(followed_id).get_flat_relative_ids(upstream=False))

    skip_tasks = [t for t in downstream_tasks if t.task_id not in kept_ids]
    follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in kept_ids]

    self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
    with create_session() as session:
        self._set_state_to_skipped(dag_run, skip_tasks, session=session)
        # For some reason, session.commit() needs to happen before xcom_push.
        # Otherwise the session is not committed.
        session.commit()
    ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids})
def test_should_continue_with_cp(load_dag):
    """Execute the ``continue_if_data`` branch task for a table whose XCom reports data.

    This is a smoke test: the branch task's return value is not asserted.
    """
    table = 'staging.users'
    dag = load_dag('bq_to_wrench').get_dag('bq_to_wrench')

    branch_task = dag.get_task(f'continue_if_data_{table}')
    assert isinstance(branch_task, BranchPythonOperator)

    task_instance = TaskInstance(task=branch_task, execution_date=datetime.now())
    XCom.set(
        key=table,
        value={'has_data': True},
        task_id=branch_task.task_id,
        dag_id=dag.dag_id,
        execution_date=task_instance.execution_date,
    )

    branch_task.execute(task_instance.get_template_context())
def test_sets_initial_checkpoint(load_dag, env, bigquery_helper):
    """Run the get-checkpoint task after deleting every checkpoint for the table.

    The task is expected to create an initial checkpoint; note that the result
    is not asserted here.
    """
    table = 'staging.users'

    # Start from a clean slate: drop all existing checkpoints for the table.
    bigquery_helper.query(
        f"DELETE FROM `{env['project']}.system.checkpoint` WHERE table = '{table}'"
    )

    checkpoint_task = load_dag('bq_to_wrench').get_dag('bq_to_wrench').get_task(
        f'get_checkpoint_{table}'
    )
    assert isinstance(checkpoint_task, GetCheckpointOperator)

    task_instance = TaskInstance(task=checkpoint_task, execution_date=datetime.now())
    checkpoint_task.execute(task_instance.get_template_context())
def test_try_adopt_task_instances(self):
    """``BaseExecutor.try_adopt_task_instances`` returns the given task instances unchanged."""
    date = datetime.utcnow()
    start_date = datetime.utcnow() - timedelta(days=2)

    with DAG("test_try_adopt_task_instances"):
        operators = [
            BaseOperator(task_id=task_id, start_date=start_date)
            for task_id in ("task_1", "task_2", "task_3")
        ]

    tis = [TaskInstance(task=operator, execution_date=date) for operator in operators]
    self.assertEqual(BaseExecutor().try_adopt_task_instances(tis), tis)
def _run_task(self, ti: TaskInstance) -> bool:
    """Run ``ti`` in this process and record the outcome via ``change_state``.

    Any per-task parameters stored in ``self.tasks_params`` are consumed and
    forwarded to the raw run.

    :return: True when the task (and the SUCCESS state change) completed,
        False when anything raised; the failure is logged with its traceback.
    """
    self.log.debug("Executing task: %s", ti)
    key = ti.key
    try:
        run_kwargs = self.tasks_params.pop(key, {})
        ti._run_raw_task(job_id=ti.job_id, **run_kwargs)  # pylint: disable=protected-access
        self.change_state(key, State.SUCCESS)
    except Exception as err:  # pylint: disable=broad-except
        self.change_state(key, State.FAILED)
        self.log.exception("Failed to execute task: %s.", str(err))
        return False
    return True
def __init__(self, ti: TaskInstance, render_templates=True):
    """Capture the (optionally rendered) template field values of ``ti``'s task.

    :param ti: task instance whose identity and template fields are recorded
    :param render_templates: when true, render ``ti``'s templates before reading
        the task's template fields; otherwise the current values are captured.
    """
    self.dag_id, self.task_id, self.task, self.execution_date = (
        ti.dag_id,
        ti.task_id,
        ti.task,
        ti.execution_date,
    )
    self.ti = ti

    if render_templates:
        ti.render_templates()

    # The pod YAML is only recorded when running with a Kubernetes-based executor.
    if IS_K8S_OR_K8SCELERY_EXECUTOR:
        self.k8s_pod_yaml = ti.render_k8s_pod_yaml()

    rendered = {}
    for field_name in self.task.template_fields:
        rendered[field_name] = serialize_template_field(getattr(self.task, field_name))
    self.rendered_fields = rendered
def test_mark_success_on_success_callback(self):
    """
    Test that ensures that when a task is marked success in the UI,
    on_success_callback gets executed.

    A LocalTaskJob runs in a separate process. Once its task instance is RUNNING,
    the TI is set to SUCCESS in the DB (as the UI would), and the job's
    heartbeat_callback is expected to invoke the success callback.
    """
    # Mutable flag so the nested callback can report that it ran.
    data = {'called': False}

    def success_callback(context):
        self.assertEqual(context['dag_run'].dag_id, 'test_mark_success')
        data['called'] = True

    dag = DAG(dag_id='test_mark_success', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

    task = DummyOperator(task_id='test_state_succeeded1', dag=dag, on_success_callback=success_callback)

    session = settings.Session()

    dag.clear()
    dag.create_dagrun(
        run_id="test",
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

    job1.task_runner = StandardTaskRunner(job1)
    process = multiprocessing.Process(target=job1.run)
    process.start()
    ti.refresh_from_db()
    # Poll for up to ~5 seconds (50 x 0.1s) for the child process to mark the TI RUNNING.
    for _ in range(0, 50):
        if ti.state == State.RUNNING:
            break
        time.sleep(0.1)
        ti.refresh_from_db()
    self.assertEqual(State.RUNNING, ti.state)
    # Simulate "mark success" from the UI by writing the state directly to the DB.
    ti.state = State.SUCCESS
    session.merge(ti)
    session.commit()

    job1.heartbeat_callback(session=None)
    self.assertTrue(data['called'])
    process.join(timeout=10)
    self.assertFalse(process.is_alive())
class PythonIdempatomicFileOperatorTest_Idempotent(unittest.TestCase):
    """Re-executing the operator must not re-run its callable once the output exists."""

    def f(self, output_path):
        """Append ``"test"`` to *output_path*, so a second invocation would be visible."""
        with open(output_path, "a+") as handle:
            handle.write("test")

    def test_PyIdempaOp_idempotent(self):
        self.dag = DAG(
            TEST_DAG_ID,
            schedule_interval="@daily",
            default_args={"start_date": datetime.now()},
        )
        with TemporaryDirectory() as tempdir:
            output_path = f"{tempdir}/test_file.txt"

            def contents():
                with open(output_path, "r") as handle:
                    return handle.read()

            # Precondition: the output must not exist yet.
            self.assertFalse(os.path.exists(output_path))

            self.op = PythonIdempatomicFileOperator(
                dag=self.dag,
                task_id="test",
                output_pattern=output_path,
                python_callable=self.f,
            )
            self.ti = TaskInstance(task=self.op, execution_date=datetime.now())

            # First run: the callable executes and creates the output.
            first_result = self.op.execute(self.ti.get_template_context())
            self.assertEqual(first_result, output_path)
            self.assertFalse(self.op.previously_completed)
            self.assertTrue(os.path.exists(output_path))
            self.assertEqual(contents(), "test")

            # Second run: the path is still returned, but the callable is skipped,
            # otherwise the file would now read "testtest".
            second_result = self.op.execute(self.ti.get_template_context())
            self.assertEqual(second_result, output_path)
            self.assertTrue(self.op.previously_completed)
            self.assertEqual(contents(), "test")

            # Sanity check: invoking the callable directly does append.
            self.f(output_path)
            self.assertEqual(contents(), "testtest")
class PythonIdempatomicFileOperatorTest_Atomic(unittest.TestCase):
    """A failing callable must not leave a partially written output file behind."""

    def g(self, output_path):
        """Write some output, then fail before the task can complete."""
        with open(output_path, "w") as fout:
            fout.write("test")
            raise ValueError("You cannot write that!")

    def test_PyIdempaOp_atomic(self):
        self.dag = DAG(
            TEST_DAG_ID,
            schedule_interval="@daily",
            default_args={"start_date": datetime.now()},
        )
        with TemporaryDirectory() as tempdir:
            output_path = f"{tempdir}/test.txt"
            self.assertFalse(os.path.exists(output_path))  # ensure doesn't already exist
            self.op = PythonIdempatomicFileOperator(
                dag=self.dag,
                task_id="test",
                output_pattern=output_path,
                python_callable=self.g,
            )
            self.ti = TaskInstance(task=self.op, execution_date=datetime.now())

            # The callable's ValueError must propagate, which proves the task ran.
            # execute() never returns here, so there is no result to inspect and
            # nothing placed after the call inside this block would ever run.
            with self.assertRaises(ValueError):
                self.op.execute(self.ti.get_template_context())

            # Atomicity: the partially written output must not be left behind.
            self.assertFalse(os.path.exists(output_path))
def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]):
    """
    This method implements the logic for a branching operator; given a single
    task ID or list of task IDs to follow, this skips all other tasks
    immediately downstream of this operator.

    Tasks that are downstream of a followed task are not skipped, even if they
    are also immediately downstream of this operator.

    :param ti: task instance of the branching operator
    :param branch_task_ids: a single task id, or an iterable of task ids, to follow
    """
    self.log.info("Following branch %s", branch_task_ids)
    # Materialize the ids exactly once. They are iterated and then
    # membership-tested below; a one-shot iterable such as a generator would be
    # exhausted by the first pass, causing every downstream task to be skipped.
    if isinstance(branch_task_ids, str):
        branch_task_ids = {branch_task_ids}
    else:
        branch_task_ids = set(branch_task_ids)

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag

    downstream_tasks = task.downstream_list

    if downstream_tasks:
        # Also check downstream tasks of the branch task. In case the task to skip
        # is also a downstream task of the branch task, we exclude it from skipping.
        branch_downstream_task_ids = set()  # type: Set[str]
        for branch_task_id in branch_task_ids:
            branch_downstream_task_ids.update(
                dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))

        kept_task_ids = branch_task_ids | branch_downstream_task_ids
        skip_tasks = [t for t in downstream_tasks if t.task_id not in kept_task_ids]

        self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
        self.skip(dag_run, ti.execution_date, skip_tasks)
def _get_ready_tis(
    self,
    scheduleable_tasks: List[TI],
    finished_tasks: List[TI],
    session: Session,
) -> Tuple[List[TI], bool]:
    """Select the schedulable task instances whose dependencies are met.

    Dependencies are evaluated with ``flag_upstream_failed=True``. Afterwards the
    database is re-queried for every TI whose dependencies were *not* met, to see
    whether its state moved during the check.

    :return: ``(ready_tis, changed_tis)`` -- the TIs that may run, and whether any
        not-ready TI now has a different state in the database than before the check.
    """
    ready_tis: List[TI] = []
    if not scheduleable_tasks:
        return ready_tis, False

    unmet_prior_states = {}
    for candidate in scheduleable_tasks:
        # Capture the state first: the dependency check may update it.
        state_before = candidate.state
        context = DepContext(flag_upstream_failed=True, finished_tasks=finished_tasks)
        if not candidate.are_dependencies_met(dep_context=context, session=session):
            unmet_prior_states[candidate.key] = state_before
            continue
        ready_tis.append(candidate)

    changed_tis = False
    tis_filter = TI.filter_for_tis(unmet_prior_states.keys())
    if tis_filter is not None:
        refreshed = session.query(TI).filter(tis_filter).all()
        changed_tis = any(fresh.state != unmet_prior_states[fresh.key] for fresh in refreshed)
    return ready_tis, changed_tis
def _start_task_instance(self, key: TaskInstanceKey):
    """
    Ignore all dependencies, force start a task instance.

    Builds a local run command with every dependency/state check disabled, marks
    the task instance QUEUED and hands the command to ``execute_async``. Logs an
    error and does nothing if the task instance is not in the DB.
    """
    ti = self.get_task_instance(key)
    if ti is None:
        self.log.error("TaskInstance not found in DB, %s.", str(key))
        return

    # Every "ignore" switch is on: this is a forced start.
    force_flags = dict(
        ignore_all_deps=True,
        ignore_depends_on_past=True,
        ignore_task_deps=True,
        ignore_ti_state=True,
    )
    command = TaskInstance.generate_command(
        ti.dag_id,
        ti.task_id,
        ti.execution_date,
        local=True,
        mark_success=False,
        pool=ti.pool,
        file_path=ti.dag_model.fileloc,
        pickle_id=ti.dag_model.pickle_id,
        server_uri=self._server_uri,
        **force_flags,
    )
    ti.set_state(State.QUEUED)
    self.execute_async(
        key=key,
        command=command,
        queue=ti.queue,
        executor_config=ti.executor_config,
    )
def test_error_sending_task(self):
    """A task whose publish raises is removed from the queue and reported as FAILED."""
    def fake_execute_command():
        pass

    with _prepare_app(execute=fake_execute_command):
        # fake_execute_command takes no arguments while execute_command takes 1,
        # which will cause TypeError when calling task.apply_async()
        executor = celery_executor.CeleryExecutor()
        task = BashOperator(
            task_id="test",
            bash_command="true",
            dag=DAG(dag_id='id'),
            start_date=datetime.now(),
        )
        when = datetime.now()
        simple_ti = SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now()))
        key = ('fail', 'fake_simple_ti', when, 0)

        executor.queued_tasks[key] = ('command', 1, None, simple_ti)
        executor.task_publish_retries[key] = 1
        executor.heartbeat()

    assert not executor.queued_tasks, "Task should no longer be queued"
    assert executor.event_buffer[key][0] == State.FAILED
def _send_stalled_tis_back_to_scheduler(
    self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION
) -> None:
    """Hand stalled task instances back to the scheduler.

    Task instances among ``keys`` that are still QUEUED by this job are reset to
    SCHEDULED, and their queued time, queuing job id and external executor id are
    cleared. Only after that commit succeeds does this executor stop tracking each
    key and try to revoke the matching Celery task. On a DB error the session is
    rolled back and no local state is changed.
    """
    reset_values = {
        TaskInstance.state: State.SCHEDULED,
        TaskInstance.queued_dttm: None,
        TaskInstance.queued_by_job_id: None,
        TaskInstance.external_executor_id: None,
    }
    try:
        stalled_query = session.query(TaskInstance).filter(
            TaskInstance.filter_for_tis(keys),
            TaskInstance.state == State.QUEUED,
            TaskInstance.queued_by_job_id == self.job_id,
        )
        stalled_query.update(reset_values, synchronize_session=False)
        session.commit()
    except Exception:
        self.log.exception("Error sending tasks back to scheduler")
        session.rollback()
        return

    for key in keys:
        self._set_celery_pending_task_timeout(key, None)
        self.running.discard(key)
        async_result = self.tasks.pop(key, None)
        if not async_result:
            continue
        try:
            app.control.revoke(async_result.task_id)
        except Exception as ex:
            self.log.error("Error revoking task instance %s from celery: %s", key, ex)
def verify_integrity(self, session=None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    :param session: database session used to read, merge and add task instances
    """
    from airflow.models.taskinstance import TaskInstance  # Avoid circular import

    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            # Compare states by value. They are plain strings, and a state loaded
            # from the database is generally not the same object as the
            # State.RUNNING constant, so an identity test would almost always be true.
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning("Failed to get task '%s' for dag '%s'. "
                                 "Marking it as removed.", ti, dag)
                Stats.incr(
                    "task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        # A task that is back in the DAG but whose TI was marked removed is revived.
        should_restore_task = task is not None and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info("Restoring task '%s' which was previously "
                          "removed from DAG '%s'", ti, dag)
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in six.itervalues(dag.task_dict):
        # Tasks that start after this run's execution date are not created,
        # except for backfills.
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(
                "task_instance_created-{}".format(task.__class__.__name__),
                1, 1)
            ti = TaskInstance(task, self.execution_date)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.commit()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info(
            'Hit IntegrityError while creating the TIs for %s - %s',
            dag.dag_id, self.execution_date
        )
        self.log.info('Doing session rollback.')
        session.rollback()
def queue_task_instance(
        self,
        task_instance: TaskInstance,
        mark_success: bool = False,
        pickle_id: Optional[str] = None,
        ignore_all_deps: bool = False,
        ignore_depends_on_past: bool = False,
        ignore_task_deps: bool = False,
        ignore_ti_state: bool = False,
        pool: Optional[str] = None,
        cfg_path: Optional[str] = None) -> None:
    """Queues task instance.

    Builds the task instance's local command list via ``command_as_list`` and
    queues it with a ``SimpleTaskInstance`` snapshot. The task's
    ``priority_weight_total`` and ``queue`` are used for priority and queue. When
    ``pool`` is not given, the task instance's own pool is used.
    """
    # TODO (edgarRd): AIRFLOW-1985:
    # cfg_path is needed to propagate the config values if using impersonation
    # (run_as_user), given that there are different code paths running tasks.
    # For a long term solution we need to address AIRFLOW-1986
    command_list_to_run = task_instance.command_as_list(
        local=True,
        mark_success=mark_success,
        ignore_all_deps=ignore_all_deps,
        ignore_depends_on_past=ignore_depends_on_past,
        ignore_task_deps=ignore_task_deps,
        ignore_ti_state=ignore_ti_state,
        pool=pool or task_instance.pool,
        pickle_id=pickle_id,
        cfg_path=cfg_path,
    )
    task = task_instance.task
    self.queue_command(
        SimpleTaskInstance(task_instance),
        command_list_to_run,
        priority=task.priority_weight_total,
        queue=task.queue,
    )
def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
    """Yield one task-instance insert mapping per map index of ``task``.

    ``created_counts`` is bumped once per task (not once per index) under the
    task's type. Because this is a generator, the bump happens only when
    iteration begins.
    """
    created_counts[task.task_type] += 1
    yield from (TI.insert_mapping(self.run_id, task, map_index=idx) for idx in indexes)
def _get_ready_tis(
    self,
    schedulable_tis: List[TI],
    finished_tis: List[TI],
    session: Session,
) -> Tuple[List[TI], bool]:
    """Return the schedulable task instances whose dependencies are met.

    An unexpanded mapped TI (``map_index < 0``) that passes its dependency check is
    expanded here as a last resort. The additional TIs produced by the expansion are
    checked later in the same loop.

    :param schedulable_tis: candidate task instances
    :param finished_tis: finished task instances; passed to the dependency context
        and scanned for an upstream of an unexpanded mapped task
    :param session: database session
    :return: ``(ready_tis, changed_tis)`` where ``changed_tis`` is True if any TI
        whose dependencies were not met now has a different state in the database
        than it had before the check
    """
    old_states = {}
    ready_tis: List[TI] = []
    changed_tis = False

    if not schedulable_tis:
        return ready_tis, changed_tis

    # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
    # `schedulable_tis` in place and have the `for` loop pick them up.)
    expanded_tis: List[TI] = []
    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
        finished_tis=finished_tis,
    )

    # Check dependencies. itertools.chain reaches `expanded_tis` only after
    # `schedulable_tis` is exhausted, so TIs appended below are still visited.
    for schedulable in itertools.chain(schedulable_tis, expanded_tis):
        old_state = schedulable.state
        if schedulable.are_dependencies_met(session=session, dep_context=dep_context):
            ready_tis.append(schedulable)
        else:
            old_states[schedulable.key] = old_state
            continue

        # Expansion of last resort! This is ideally handled in the mini-scheduler in LocalTaskJob, but if
        # for any reason it wasn't, we need to expand it now
        if schedulable.map_index < 0 and schedulable.task.is_mapped:
            # HACK. This needs a better way, one that copes with multiple upstreams!
            for ti in finished_tis:
                if schedulable.task_id in ti.task.downstream_task_ids:

                    assert isinstance(schedulable.task, MappedOperator)
                    new_tis = schedulable.task.expand_mapped_task(self.run_id, session=session)
                    if schedulable.state == TaskInstanceState.SKIPPED:
                        # Task is now skipped (likely because upstream returned 0 tasks).
                        # NOTE(review): this `continue` applies to the inner
                        # `for ti in finished_tis` loop, so later finished upstreams are
                        # still scanned and expand_mapped_task may run again. Also,
                        # `schedulable` was already appended to ready_tis above even
                        # though it is now SKIPPED -- confirm this is intended.
                        continue
                    # The expansion re-uses the unexpanded TI as its first element.
                    assert new_tis[0] is schedulable
                    expanded_tis.extend(new_tis[1:])
                    break

    # Check if any ti changed state
    tis_filter = TI.filter_for_tis(old_states.keys())
    if tis_filter is not None:
        fresh_tis = session.query(TI).filter(tis_filter).all()
        changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

    return ready_tis, changed_tis
def get_link(self, operator, dttm):
    """
    Get link to qubole command result page.

    When the connection has a host, its trailing ``api`` is replaced by the
    analyze path; otherwise the public Qubole host is used. Returns an empty
    string when no command id was pushed to XCom.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    ti = TaskInstance(task=operator, execution_date=dttm)
    conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
    host = (
        re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        if conn and conn.host
        else 'https://api.qubole.com/v2/analyze?command_id='
    )

    qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
    if not qds_command_id:
        return ''
    return host + str(qds_command_id)
def skip_all_except(
    self, ti: TaskInstance, branch_task_ids: Union[str, Iterable[str]]
):
    """
    This method implements the logic for a branching operator; given a single
    task ID or list of task IDs to follow, this skips all other tasks
    immediately downstream of this operator.

    branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
    newly added tasks should be skipped when they are cleared.

    :param ti: task instance of the branching operator
    :param branch_task_ids: a single task id, or an iterable of task ids, to follow
    """
    self.log.info("Following branch %s", branch_task_ids)
    if isinstance(branch_task_ids, str):
        branch_task_ids = [branch_task_ids]
    else:
        # Materialize exactly once. The ids are iterated, membership-tested and
        # pushed to XCom below. A one-shot iterable such as a generator would be
        # exhausted by the first pass: later membership tests would fail (so every
        # downstream task is skipped) and an empty, exhausted generator would be
        # pushed to XCom.
        branch_task_ids = list(branch_task_ids)

    dag_run = ti.get_dagrun()
    task = ti.task
    dag = task.dag

    downstream_tasks = task.downstream_list

    if downstream_tasks:
        # Also check downstream tasks of the branch task. In case the task to skip
        # is also a downstream task of the branch task, we exclude it from skipping.
        branch_downstream_task_ids = set()  # type: Set[str]
        for branch_task_id in branch_task_ids:
            branch_downstream_task_ids.update(
                dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False)
            )

        kept_task_ids = set(branch_task_ids) | branch_downstream_task_ids
        skip_tasks = [t for t in downstream_tasks if t.task_id not in kept_task_ids]

        self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
        with create_session() as session:
            self._set_state_to_skipped(
                dag_run, ti.execution_date, skip_tasks, session=session
            )
            # Commit the skipped states explicitly before pushing to XCom.
            session.commit()
        ti.xcom_push(
            key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids}
        )
def _get_ready_tis(
    self,
    schedulable_tis: List[TI],
    finished_tis: List[TI],
    session: Session,
) -> Tuple[List[TI], bool, bool]:
    """Return the schedulable task instances that may run now.

    A TI whose dependencies are met and that belongs to a not-yet-expanded mapped
    task (``map_index < 0``) is expanded here. The extra TIs from the expansion are
    checked later in the same loop. A TI is reported ready only if its state, after
    any expansion, is in ``SCHEDULEABLE_STATES``.

    :param schedulable_tis: candidate task instances
    :param finished_tis: finished task instances, passed to the dependency context
    :param session: database session
    :return: ``(ready_tis, changed_tis, expansion_happened)`` where ``changed_tis``
        is True if any TI whose dependencies were not met now has a different state
        in the database than before the check, and ``expansion_happened`` is True if
        ``expand_mapped_task`` was called for any TI (even if it returned no TIs)
    """
    old_states = {}
    ready_tis: List[TI] = []
    changed_tis = False

    if not schedulable_tis:
        return ready_tis, changed_tis, False

    # If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
    # `schedulable_tis` in place and have the `for` loop pick them up.)
    additional_tis: List[TI] = []
    dep_context = DepContext(
        flag_upstream_failed=True,
        ignore_unmapped_tasks=True,  # Ignore this Dep, as we will expand it if we can.
        finished_tis=finished_tis,
    )

    # Check dependencies.
    expansion_happened = False
    for schedulable in itertools.chain(schedulable_tis, additional_tis):
        old_state = schedulable.state
        if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
            old_states[schedulable.key] = old_state
            continue
        # If schedulable is from a mapped task, but not yet expanded, do it
        # now. This is called in two places: First and ideally in the mini
        # scheduler at the end of LocalTaskJob, and then as an "expansion of
        # last resort" in the scheduler to ensure that the mapped task is
        # correctly expanded before executed.
        if schedulable.map_index < 0 and isinstance(schedulable.task, MappedOperator):
            expanded_tis, _ = schedulable.task.expand_mapped_task(self.run_id, session=session)
            if expanded_tis:
                # The expansion re-uses the unexpanded TI as its first element.
                assert expanded_tis[0] is schedulable
                additional_tis.extend(expanded_tis[1:])
            expansion_happened = True
        # Re-check the state: the expansion may have changed it (e.g. no TIs produced).
        if schedulable.state in SCHEDULEABLE_STATES:
            ready_tis.append(schedulable)

    # Check if any ti changed state
    tis_filter = TI.filter_for_tis(old_states)
    if tis_filter is not None:
        fresh_tis = session.query(TI).filter(tis_filter).all()
        changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

    return ready_tis, changed_tis, expansion_happened
def test_open_slots(self):
    """A 5-slot pool holding one RUNNING and one QUEUED task instance has 3 open slots."""
    pool = Pool(pool='test_pool', slots=5)
    dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)

    operators = [
        DummyOperator(task_id=task_id, dag=dag, pool='test_pool')
        for task_id in ('dummy1', 'dummy2')
    ]
    running_ti, queued_ti = (TI(task=op, execution_date=DEFAULT_DATE) for op in operators)
    running_ti.state = State.RUNNING
    queued_ti.state = State.QUEUED

    session = settings.Session
    for record in (pool, running_ti, queued_ti):
        session.add(record)
    session.commit()
    session.close()

    self.assertEqual(3, pool.open_slots())
def test_get_redirect_url(self):
    """The 'Go to QDS' link embeds the pushed command id, and is empty when none exists."""
    dag = DAG(DAG_ID, start_date=DEFAULT_DATE)

    with dag:
        task = QuboleOperator(
            task_id=TASK_ID,
            qubole_conn_id=TEST_CONN,
            command_type='shellcmd',
            parameters="param1 param2",
            dag=dag,
        )

    ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti.xcom_push('qbol_cmd_id', 12345)

    # positive case: a command id exists for DEFAULT_DATE
    self.assertEqual(
        task.get_extra_links(DEFAULT_DATE, 'Go to QDS'),
        'http://localhost/v2/analyze?command_id=12345',
    )
    # negative case: nothing was pushed for this date
    self.assertEqual(task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS'), '')
def test_localtaskjob_double_trigger(self):
    """A second LocalTaskJob for a TI already recorded as RUNNING must not start it again.

    The TI is stored as RUNNING with this host's hostname and pid 1. Running a new
    job for it must not call the task runner's ``start``, and the stored pid and
    state must be left unchanged.
    """
    dagbag = DagBag(
        dag_folder=TEST_DAG_FOLDER,
        include_examples=False,
    )
    dag = dagbag.dags.get('test_localtaskjob_double_trigger')
    task = dag.get_task('test_localtaskjob_double_trigger_task')

    session = settings.Session()

    dag.clear()
    dr = dag.create_dagrun(
        run_id="test",
        state=State.SUCCESS,
        execution_date=DEFAULT_DATE,
        start_date=DEFAULT_DATE,
        session=session,
    )
    # Pretend another process is already running this TI on this host.
    ti = dr.get_task_instance(task_id=task.task_id, session=session)
    ti.state = State.RUNNING
    ti.hostname = get_hostname()
    ti.pid = 1
    session.merge(ti)
    session.commit()

    ti_run = TaskInstance(task=task, execution_date=DEFAULT_DATE)
    ti_run.refresh_from_db()
    job1 = LocalTaskJob(task_instance=ti_run, executor=SequentialExecutor())
    from airflow.task.task_runner.standard_task_runner import StandardTaskRunner

    with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_method:
        job1.run()
        mock_method.assert_not_called()

    # The original run's bookkeeping must be untouched.
    ti = dr.get_task_instance(task_id=task.task_id, session=session)
    self.assertEqual(ti.pid, 1)
    self.assertEqual(ti.state, State.RUNNING)

    session.close()
def skip(self, dag_run, execution_date, tasks, session=None):
    """
    Sets tasks instances to skipped from the same dag run.

    With a dag run, the matching task instances are updated in bulk. Without one,
    each task's instance for ``execution_date`` is merged in as SKIPPED.

    :param dag_run: the DagRun for which to set the tasks to skipped
    :param execution_date: execution_date
    :param tasks: tasks to skip (not task_ids)
    :param session: db session to use
    :raises ValueError: if there is no dag run and ``execution_date`` is None
    """
    if not tasks:
        return

    now = timezone.utcnow()

    if not dag_run:
        if execution_date is None:
            raise ValueError("Execution date is None and no dag run")

        self.log.warning("No DAG RUN present this should not happen")
        # this is defensive against dag runs that are not complete
        for task in tasks:
            ti = TaskInstance(task, execution_date=execution_date)
            ti.state = State.SKIPPED
            ti.start_date = now
            ti.end_date = now
            session.merge(ti)
        session.commit()
        return

    task_ids = [d.task_id for d in tasks]
    matching = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag_run.dag_id,
        TaskInstance.execution_date == dag_run.execution_date,
        TaskInstance.task_id.in_(task_ids),
    )
    matching.update(
        {
            TaskInstance.state: State.SKIPPED,
            TaskInstance.start_date: now,
            TaskInstance.end_date: now,
        },
        synchronize_session=False,
    )
    session.commit()
def test_open_slots(self):
    """Slot accounting for a 5-slot pool holding one RUNNING and one QUEUED task instance."""
    pool = Pool(pool='test_pool', slots=5)
    dag = DAG(dag_id='test_open_slots', start_date=DEFAULT_DATE)

    operators = [
        DummyOperator(task_id=task_id, dag=dag, pool='test_pool')
        for task_id in ('dummy1', 'dummy2')
    ]
    running_ti, queued_ti = (TI(task=op, execution_date=DEFAULT_DATE) for op in operators)
    running_ti.state = State.RUNNING
    queued_ti.state = State.QUEUED

    session = settings.Session
    for record in (pool, running_ti, queued_ti):
        session.add(record)
    session.commit()
    session.close()

    self.assertEqual(3, pool.open_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(1, pool.running_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter

    expected_stats = {
        "default_pool": {"open": 128, "queued": 0, "total": 128, "running": 0},
        "test_pool": {"open": 3, "queued": 1, "running": 1, "total": 5},
    }
    self.assertEqual(expected_stats, pool.slots_stats())
def skip(self, dag_run, execution_date, tasks, session=None):
    """
    Sets tasks instances to skipped from the same dag run.

    With a dag run, the matching task instances are updated in bulk. Without one,
    each task's instance for ``execution_date`` is merged in as SKIPPED.

    :param dag_run: the DagRun for which to set the tasks to skipped
    :param execution_date: execution_date
    :param tasks: tasks to skip (not task_ids)
    :param session: db session to use
    :raises ValueError: if there is no dag run and ``execution_date`` is None
    """
    if not tasks:
        return

    task_ids = [d.task_id for d in tasks]
    now = timezone.utcnow()

    if dag_run:
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.execution_date == dag_run.execution_date,
            TaskInstance.task_id.in_(task_ids)
        ).update({TaskInstance.state: State.SKIPPED,
                  TaskInstance.start_date: now,
                  TaskInstance.end_date: now},
                 synchronize_session=False)
        session.commit()
    else:
        # Explicit check instead of ``assert``: assertions are stripped when Python
        # runs with -O, which would let a None execution_date slip through.
        if execution_date is None:
            raise ValueError("Execution date is None and no dag run")

        self.log.warning("No DAG RUN present this should not happen")
        # this is defensive against dag runs that are not complete
        for task in tasks:
            ti = TaskInstance(task, execution_date=execution_date)
            ti.state = State.SKIPPED
            ti.start_date = now
            ti.end_date = now
            session.merge(ti)

        session.commit()
def get_link(self, operator, dttm):
    """Return the BigQuery console URL for the job id ``operator`` pushed to XCom.

    :param operator: the operator whose ``job_id`` XCom is read
    :param dttm: execution date of the task instance to read from
    :return: the console URL, or an empty string when no job id was pushed
    """
    ti = TaskInstance(task=operator, execution_date=dttm)
    job_id = ti.xcom_pull(task_ids=operator.task_id, key='job_id')
    if not job_id:
        return ''
    return 'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=job_id)