def test_delete_pool(self):
    """Deleting an existing pool removes its row from the Pool table."""
    self.client.create_pool(name='foo', slots=1, description='')
    with create_session() as session:
        self.assertEqual(1, session.query(models.Pool).count())

    self.client.delete_pool(name='foo')
    with create_session() as session:
        self.assertEqual(0, session.query(models.Pool).count())
def test_delete_dag(self):
    # Purpose: the API client's delete_dag removes the DagModel row.
    key = "my_dag_id"
    with create_session() as session:
        # Precondition: no row exists for this dag_id yet.
        self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 0)
        session.add(DagModel(dag_id=key))
    with create_session() as session:
        # The row added above was committed when the first session closed.
        self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 1)

        self.client.delete_dag(dag_id=key)
        # NOTE(review): this final count runs on the same ambient session as
        # the pre-delete assert; presumably delete_dag commits through its
        # own session and the re-executed query sees the deletion — confirm.
        self.assertEqual(session.query(DagModel).filter(DagModel.dag_id == key).count(), 0)
def test_without_dag_run(self):
    """Defensive check for tasks that have no dag-run record."""
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: 'branch_1')
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )
        # Expected terminal state per task; branch_1 exists with state None.
        expected_states = {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        }
        for ti in tis:
            if ti.task_id not in expected_states:
                raise Exception
            self.assertEqual(expected_states[ti.task_id], ti.state)
def snapshot_state(dag, execution_dates):
    """Return every TaskInstance of *dag* for the given execution dates."""
    TI = models.TaskInstance
    with create_session() as session:
        query = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates),
        )
        return query.all()
def tearDown(self):
    """Clear both test DAGs and purge run/instance tables."""
    for dag in (self.dag1, self.dag2):
        dag.clear()
    with create_session() as session:
        for model in (models.DagRun, models.TaskInstance):
            session.query(model).delete()
def setUp(self):
    """Seed one row per model so the delete tests have data to remove."""
    self.key = "test_dag_id"
    dag = models.DAG(dag_id=self.key,
                     default_args={'start_date': days_ago(2)})
    task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
    when = days_ago(1)
    with create_session() as session:
        session.add(DM(dag_id=self.key))
        session.add(DR(dag_id=self.key))
        session.add(TI(task=task, execution_date=when, state=State.SUCCESS))
        # Flush so the task instance row exists before the task reschedule
        # row is inserted (FK constraint).
        session.flush()
        session.add(LOG(dag_id=self.key, task_id=None, task_instance=None,
                        execution_date=when, event="varimport"))
        session.add(TF(task=task, execution_date=when,
                       start_date=when, end_date=when))
        session.add(TR(task=task, execution_date=when, start_date=when,
                       end_date=when, try_number=1, reschedule_date=when))
def test_none_skipped_tr_success(self):
    """None-skipped trigger rule success."""
    ti = self._get_task_instance(
        TriggerRule.NONE_SKIPPED,
        upstream_task_ids=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"])
    with create_session() as session:
        # Rule passes with zero skipped upstreams, with or without the
        # flag_upstream_failed bookkeeping enabled.
        for flag, n_success, n_failed in ((False, 2, 1), (True, 0, 3)):
            dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
                ti=ti,
                successes=n_success,
                skipped=0,
                failed=n_failed,
                upstream_failed=0,
                done=3,
                flag_upstream_failed=flag,
                session=session))
            self.assertEqual(len(dep_statuses), 0)
def test_branch_list_without_dag_run(self):
    """BranchPythonOperator can branch off to a list of tasks."""
    self.branch_op = BranchPythonOperator(
        task_id='make_choice',
        dag=self.dag,
        python_callable=lambda: ['branch_1', 'branch_2'])
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
    self.branch_3.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(
            TI.dag_id == self.dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(State.SUCCESS, ti.state)
            elif ti.task_id in ('branch_1', 'branch_2'):
                # Followed branches exist with state None.
                self.assertEqual(State.NONE, ti.state)
            elif ti.task_id == 'branch_3':
                self.assertEqual(State.SKIPPED, ti.state)
            else:
                raise Exception
def test_none_skipped_tr_failure(self):
    """None-skipped trigger rule failure."""
    ti = self._get_task_instance(
        TriggerRule.NONE_SKIPPED,
        upstream_task_ids=["FakeTaskID", "SkippedTaskID"])
    with create_session() as session:
        # A skipped upstream fails the rule regardless of the
        # flag_upstream_failed setting.
        for flag in (False, True):
            dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
                ti=ti,
                successes=1,
                skipped=1,
                failed=0,
                upstream_failed=0,
                done=2,
                flag_upstream_failed=flag,
                session=session))
            self.assertEqual(len(dep_statuses), 1)
            self.assertFalse(dep_statuses[0].passed)
def tearDown(self):
    self.dag1.clear()
    self.dag2.clear()
    # Belt and braces: wipe anything the tests may have left behind.
    with create_session() as session:
        session.query(models.DagRun).delete()
        session.query(models.TaskInstance).delete()
def tearDown(self):
    """Remove every fixture row created by setUp for this dag id."""
    with create_session() as session:
        for model in (TR, TF, TI, DR, DM, LOG):
            session.query(model).filter(model.dag_id == self.key).delete()
def test_delete_dag_successful_delete(self):
    """delete_dag removes all records except the audit Log entry."""
    deleted_models = (DM, DR, TI, TF, TR)
    with create_session() as session:
        for model in deleted_models:
            self.assertEqual(
                session.query(model).filter(model.dag_id == self.key).count(), 1)
        self.assertEqual(
            session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)

    delete_dag(dag_id=self.key)

    with create_session() as session:
        for model in deleted_models:
            self.assertEqual(
                session.query(model).filter(model.dag_id == self.key).count(), 0)
        # Log records are kept by default.
        self.assertEqual(
            session.query(LOG).filter(LOG.dag_id == self.key).count(), 1)
def test_create_pool_existing(self):
    """Creating a pool with an existing name updates it in place."""
    existing_name = self.pools[0].pool
    pool = pool_api.create_pool(name=existing_name, slots=5, description='')
    self.assertEqual(existing_name, pool.pool)
    self.assertEqual(5, pool.slots)
    self.assertEqual('', pool.description)
    with create_session() as session:
        # No new row: still only the two pools from setUp.
        self.assertEqual(2, session.query(models.Pool).count())
def test_create_pool(self):
    """Creating a new pool returns it and adds a third row."""
    pool = pool_api.create_pool(name='foo', slots=5, description='')
    self.assertEqual('foo', pool.pool)
    self.assertEqual(5, pool.slots)
    self.assertEqual('', pool.description)
    with create_session() as session:
        # Two pools from setUp plus the one just created.
        self.assertEqual(3, session.query(models.Pool).count())
def test_delete_dag_successful_delete_not_keeping_records_in_log(self):
    """With keep_records_in_log=False, even Log rows are removed."""
    all_models = (DM, DR, TI, TF, TR, LOG)
    with create_session() as session:
        for model in all_models:
            self.assertEqual(
                session.query(model).filter(model.dag_id == self.key).count(), 1)

    delete_dag(dag_id=self.key, keep_records_in_log=False)

    with create_session() as session:
        for model in all_models:
            self.assertEqual(
                session.query(model).filter(model.dag_id == self.key).count(), 0)
def default_action_log(log, **_):
    """
    Default action-logger callback.

    Behaves the same as www.utils.action_logging, which uses a global
    session and pushes the Log ORM object.

    :param log: a Log ORM instance to persist
    :param **_: extra keyword arguments, ignored by this callback
    :return: None
    """
    with create_session() as session:
        session.add(log)
def tearDown(self):
    super().tearDown()
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TI).delete()
    # Drop any task-instance context env vars the tests may have set.
    for var in TI_CONTEXT_ENV_VARS:
        os.environ.pop(var, None)
def setUp(self):
    """Create two 'experimental_N' pools (slots 0 and 1) to query against."""
    self.pools = []
    for number in (1, 2):
        name = 'experimental_%s' % number
        self.pools.append(models.Pool(
            pool=name,
            slots=number - 1,
            description=name,
        ))
    with create_session() as session:
        session.add_all(self.pools)
def test_without_dag_run(self):
    """This checks the defensive against non existent tasks in a dag run"""
    value = False
    dag = DAG('shortcircuit_operator_test_without_dag_run',
              default_args={
                  'owner': 'airflow',
                  'start_date': DEFAULT_DATE
              },
              schedule_interval=INTERVAL)
    short_op = ShortCircuitOperator(task_id='make_choice',
                                    dag=dag,
                                    python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_1.set_upstream(short_op)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    branch_2.set_upstream(branch_1)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream.set_downstream(short_op)
    dag.clear()

    # Phase 1: callable returns False -> downstream tasks are skipped.
    short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    with create_session() as session:
        tis = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date == DEFAULT_DATE
        )
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'upstream':
                # should not exist
                raise Exception
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.SKIPPED)
            else:
                raise Exception

        # Phase 2: callable now returns True (the lambda closes over
        # `value`) -> downstream tasks are left with state None.
        value = True
        dag.clear()
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        # NOTE(review): `tis` is a lazy Query, so iterating it again
        # re-executes the SELECT and picks up the post-clear states —
        # relies on Query re-execution semantics; confirm intent.
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'upstream':
                # should not exist
                raise Exception
            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                self.assertEqual(ti.state, State.NONE)
            else:
                raise Exception
def test_deactivate_unknown_dags(self):
    """
    DAG ids not passed into deactivate_unknown_dags are deactivated
    when the function is invoked.
    """
    dagbag = DagBag(include_examples=True)
    dag_id = "test_deactivate_unknown_dags"
    expected_active_dags = dagbag.dags.keys()

    model_before = DagModel(dag_id=dag_id, is_active=True)
    with create_session() as session:
        session.merge(model_before)

    models.DAG.deactivate_unknown_dags(expected_active_dags)

    after_model = DagModel.get_dagmodel(dag_id)
    # The detached pre-image keeps its value; the DB row was flipped.
    self.assertTrue(model_before.is_active)
    self.assertFalse(after_model.is_active)

    # clean up
    with create_session() as session:
        session.query(DagModel).filter(DagModel.dag_id == dag_id).delete()
def _labels_to_key(self, labels):
    """
    Resolve pod labels back to a task-instance key.

    :param labels: pod label dict, expected to carry 'dag_id', 'task_id',
        'execution_date' and optionally 'try_number'
    :return: tuple (dag_id, task_id, execution_date, try_number), or
        None when the labels cannot be matched to a task instance
    """
    try_num = 1
    try:
        try_num = int(labels.get('try_number', '1'))
    except ValueError:
        # Fix: Logger.warn is a deprecated alias of Logger.warning.
        self.log.warning("could not get try_number as an int: %s",
                         labels.get('try_number', '1'))

    try:
        dag_id = labels['dag_id']
        task_id = labels['task_id']
        ex_time = self._label_safe_datestring_to_datetime(
            labels['execution_date'])
    except Exception as e:
        self.log.warning(
            'Error while retrieving labels; labels: %s; exception: %s',
            labels, e
        )
        return None

    with create_session() as session:
        # Labels are sanitized, so compare the sanitized form of each DB
        # value against the label values.
        tasks = (
            session
            .query(TaskInstance)
            .filter_by(execution_date=ex_time).all()
        )
        self.log.info(
            'Checking %s task instances.',
            len(tasks)
        )
        for task in tasks:
            if (
                self._make_safe_label_value(task.dag_id) == dag_id and
                self._make_safe_label_value(task.task_id) == task_id and
                task.execution_date == ex_time
            ):
                self.log.info(
                    'Found matching task %s-%s (%s) with current state of %s',
                    task.dag_id, task.task_id, task.execution_date,
                    task.state
                )
                # Return the unsanitized ids from the DB record.
                dag_id = task.dag_id
                task_id = task.task_id
                return (dag_id, task_id, ex_time, try_num)
    self.log.warning(
        'Failed to find and match task details to a pod; labels: %s',
        labels
    )
    return None
def execute(self, context):
    """Ask python_callable whether to trigger, then create the DagRun."""
    dro = DagRunOrder(run_id='trig__' + timezone.utcnow().isoformat())
    dro = self.python_callable(context, dro)
    # Guard clause: a falsy DagRunOrder means "do not trigger".
    if not dro:
        self.log.info("Criteria not met, moving on")
        return
    with create_session() as session:
        dbag = DagBag(settings.DAGS_FOLDER)
        trigger_dag = dbag.get_dag(self.trigger_dag_id)
        dr = trigger_dag.create_dagrun(run_id=dro.run_id,
                                       state=State.RUNNING,
                                       conf=dro.payload,
                                       external_trigger=True)
        self.log.info("Creating DagRun %s", dr)
        session.add(dr)
        session.commit()
def dag_paused(dag_id, paused):
    """(Un)pauses a dag"""
    DagModel = models.DagModel
    with create_session() as session:
        orm_dag = (
            session.query(DagModel)
            .filter(DagModel.dag_id == dag_id).first()
        )
        # The string 'true' pauses; anything else unpauses.
        orm_dag.is_paused = paused == 'true'
        session.merge(orm_dag)
        session.commit()

    return jsonify({'response': 'ok'})
def decorated(*args, **kwargs):
    """HTTP basic-auth gate: authenticate, then run the wrapped view."""
    from flask import request

    header = request.headers.get("Authorization")
    if not header:
        return _unauthorized()

    # Strip the scheme token ("Basic") and decode "user:pass".
    userpass = ''.join(header.split()[1:])
    username, password = base64.b64decode(userpass).decode("utf-8").split(":", 1)

    with create_session() as session:
        try:
            authenticate(session, username, password)
            response = make_response(function(*args, **kwargs))
            return response
        except AuthenticationError:
            return _forbidden()
def _change_state(self, key, state, pod_id):
    """Record *state* for *key*; delete the pod unless it is still running."""
    if state != State.RUNNING:
        self.kube_scheduler.delete_pod(pod_id)
        try:
            self.log.info('Deleted pod: %s', str(key))
            self.running.pop(key)
        except KeyError:
            # Already removed from the running set; nothing to do.
            self.log.debug('Could not find key: %s', str(key))
    self.event_buffer[key] = state
    dag_id, task_id, ex_time, try_number = key
    with create_session() as session:
        item = session.query(TaskInstance).filter_by(
            dag_id=dag_id,
            task_id=task_id,
            execution_date=ex_time
        ).one()
        if state:
            item.state = state
            session.add(item)
def test_delete_dag_dag_still_in_dagbag(self):
    """delete_dag refuses (DagFileExists) and leaves the DB untouched."""
    with create_session() as session:
        models_to_check = ['DagModel', 'DagRun', 'TaskInstance']
        record_counts = {
            model_name: session.query(getattr(models, model_name)).filter(
                getattr(models, model_name).dag_id == self.dag_id).count()
            for model_name in models_to_check
        }

        with self.assertRaises(DagFileExists):
            delete_dag(self.dag_id)

        # No change should happen in DB
        for model_name in models_to_check:
            m = getattr(models, model_name)
            self.assertEqual(
                session.query(m).filter(m.dag_id == self.dag_id).count(),
                record_counts[model_name]
            )
def test_kill_zombies(self, mock_ti_handle_failure):
    """kill_zombies invokes the TI failure handler with proper context."""
    dagbag = models.DagBag()
    with create_session() as session:
        session.query(TI).delete()
        task = dagbag.get_dag('example_branch_operator') \
            .get_task(task_id='run_this_first')
        ti = TI(task, DEFAULT_DATE, State.RUNNING)
        session.add(ti)
        session.commit()

        zombies = [SimpleTaskInstance(ti)]
        dagbag.kill_zombies(zombies)
        mock_ti_handle_failure.assert_called_with(
            ANY,
            configuration.getboolean('core', 'unit_test_mode'),
            ANY
        )
def wrapper(*args, **kwargs):
    """Persist a Log row describing this request, then call the view."""
    if current_user and hasattr(current_user, 'username'):
        user = current_user.username
    else:
        user = '******'

    log = models.Log(
        event=f.__name__,
        task_instance=None,
        owner=user,
        extra=str(list(request.args.items())),
        task_id=request.args.get('task_id'),
        dag_id=request.args.get('dag_id'))
    if 'execution_date' in request.args:
        log.execution_date = timezone.parse(
            request.args.get('execution_date'))

    with create_session() as session:
        session.add(log)
        session.commit()
    return f(*args, **kwargs)
def test_find_zombies(self):
    """A SHUTDOWN LocalJob with a RUNNING TI is reported as a zombie."""
    manager = DagFileProcessorManager(
        dag_directory='directory',
        file_paths=['abc.txt'],
        max_runs=1,
        processor_factory=MagicMock().return_value,
        signal_conn=MagicMock(),
        stat_queue=MagicMock(),
        # Fix: was `MagicMock` (the class itself); instantiate it like the
        # sibling queue/conn arguments.
        result_queue=MagicMock(),
        async_mode=True)

    dagbag = DagBag(TEST_DAG_FOLDER)
    with create_session() as session:
        session.query(LJ).delete()
        dag = dagbag.get_dag('example_branch_operator')
        task = dag.get_task(task_id='run_this_first')

        ti = TI(task, DEFAULT_DATE, State.RUNNING)
        # A job in SHUTDOWN state that owns a RUNNING TI marks it a zombie.
        lj = LJ(ti)
        lj.state = State.SHUTDOWN
        lj.id = 1
        ti.job_id = lj.id

        session.add(lj)
        session.add(ti)
        session.commit()

        # Force the zombie query to run by aging the last query time past
        # the threshold.
        manager._last_zombie_query_time = timezone.utcnow() - timedelta(
            seconds=manager._zombie_threshold_secs + 1)
        zombies = manager._find_zombies()
        self.assertEqual(1, len(zombies))
        self.assertIsInstance(zombies[0], SimpleTaskInstance)
        self.assertEqual(ti.dag_id, zombies[0].dag_id)
        self.assertEqual(ti.task_id, zombies[0].task_id)
        self.assertEqual(ti.execution_date, zombies[0].execution_date)

        session.query(TI).delete()
        session.query(LJ).delete()
def test_kill_zombies(self, mock_ti_handle_failure):
    """kill_zombies invokes the TI failure handler with proper context."""
    dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
    with create_session() as session:
        session.query(TI).delete()
        task = dagbag.get_dag('example_branch_operator') \
            .get_task(task_id='run_this_first')
        ti = TI(task, DEFAULT_DATE, State.RUNNING)
        session.add(ti)
        session.commit()

        zombies = [SimpleTaskInstance(ti)]
        dagbag.kill_zombies(zombies)
        mock_ti_handle_failure.assert_called_with(
            ANY,
            configuration.getboolean('core', 'unit_test_mode'),
            ANY
        )
def test_without_dag_run(self):
    """Defensive check for tasks that have no dag-run record."""
    self.branch_op = ChooseBranchOne(task_id="make_choice", dag=self.dag)
    for downstream in (self.branch_1, self.branch_2):
        downstream.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    with create_session() as session:
        tis = session.query(TI).filter(TI.dag_id == self.dag.dag_id,
                                       TI.execution_date == DEFAULT_DATE)
        # branch_1 should exist with state None; branch_2 is skipped.
        expected = {
            'make_choice': State.SUCCESS,
            'branch_1': State.NONE,
            'branch_2': State.SKIPPED,
        }
        for ti in tis:
            if ti.task_id not in expected:
                raise Exception
            self.assertEqual(expected[ti.task_id], ti.state)
def wrapper(*args, **kwargs):
    """Persist a Log row for the wrapped view, then invoke it."""
    with create_session() as session:
        user = '******' if g.user.is_anonymous else g.user.username

        log = Log(
            event=f.__name__,
            task_instance=None,
            owner=user,
            extra=str(list(request.args.items())),
            task_id=request.args.get('task_id'),
            dag_id=request.args.get('dag_id'))

        if 'execution_date' in request.args:
            log.execution_date = pendulum.parse(
                request.args.get('execution_date'))

        session.add(log)

    return f(*args, **kwargs)
def test_not_requeue_non_requeueable_task_instance(self): dag = models.DAG(dag_id='test_not_requeue_non_requeueable_task_instance') # Use BaseSensorOperator because sensor got # one additional DEP in BaseSensorOperator().deps task = BaseSensorOperator( task_id='test_not_requeue_non_requeueable_task_instance_op', dag=dag, pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) ti = TI( task=task, execution_date=timezone.utcnow(), state=State.QUEUED) with create_session() as session: session.add(ti) session.commit() all_deps = RUNNING_DEPS | task.deps all_non_requeueable_deps = all_deps - REQUEUEABLE_DEPS patch_dict = {} for dep in all_non_requeueable_deps: class_name = dep.__class__.__name__ dep_patch = patch('%s.%s.%s' % (dep.__module__, class_name, dep._get_dep_statuses.__name__)) method_patch = dep_patch.start() method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True, 'mock')]) patch_dict[class_name] = (dep_patch, method_patch) for class_name, (dep_patch, method_patch) in patch_dict.items(): method_patch.return_value = iter( [TIDepStatus('mock_' + class_name, False, 'mock')]) ti.run() self.assertEqual(ti.state, State.QUEUED) dep_patch.return_value = TIDepStatus('mock_' + class_name, True, 'mock') for (dep_patch, method_patch) in patch_dict.values(): dep_patch.stop()
def test_rerun_failed_subdag(self):
    """
    When there is an existing DagRun with failed state, reset the DagRun
    and the corresponding TaskInstances
    """
    dag = DAG('parent', default_args=default_args)
    subdag = DAG('parent.test', default_args=default_args)
    subdag_task = SubDagOperator(task_id='test', subdag=subdag,
                                 dag=dag, poke_interval=1)
    dummy_task = DummyOperator(task_id='dummy', dag=subdag)

    # Seed a FAILED task instance and a FAILED dag run for the subdag.
    with create_session() as session:
        dummy_task_instance = TaskInstance(
            task=dummy_task,
            execution_date=DEFAULT_DATE,
            state=State.FAILED,
        )
        session.add(dummy_task_instance)
        session.commit()

    sub_dagrun = subdag.create_dagrun(
        run_id="scheduled__{}".format(DEFAULT_DATE.isoformat()),
        execution_date=DEFAULT_DATE,
        state=State.FAILED,
        external_trigger=True,
    )

    subdag_task._reset_dag_run_and_task_instances(
        sub_dagrun, execution_date=DEFAULT_DATE)

    # The reset clears the TI state and restarts the dag run.
    dummy_task_instance.refresh_from_db()
    self.assertEqual(State.NONE, dummy_task_instance.state)

    sub_dagrun.refresh_from_db()
    self.assertEqual(State.RUNNING, sub_dagrun.state)
def __init__(self, stream):
    """
    Render *stream* through Jinja2 before handing it to the parent loader.

    The template context exposes the process environment, all Airflow
    Variables, and curated slices of the pytz/datetime/re modules.
    """
    with create_session() as session:
        airflow_variables = {
            var.key: var.val for var in session.query(Variable)
        }
    self._root = os.path.split(stream.name)[0]

    datetime_names = ["date", "datetime", "time", "timedelta", "tzinfo"]
    ctx = {
        "env": os.environ,
        "airflow_variables": airflow_variables,
        "pytz": {name: getattr(pytz, name) for name in pytz.__all__},
        "datetime": {name: getattr(datetime, name) for name in datetime_names},
        "re": {name: getattr(re, name) for name in re.__all__},
    }

    # Enable Jinja2 in the yaml files
    yaml_stream = StringIO(Template(stream.read()).render(ctx))
    yaml_stream.name = stream.name
    super().__init__(yaml_stream)
def branched_context():
    """
    Generic Airflow context fixture with branched tasks that can be
    passed to a callable that requires context.
    """
    with create_session() as session:
        dag = DAGFactory()
        branch_a = PythonOperatorFactory(task_id='branch_a', dag=dag)
        branch_b = PythonOperatorFactory(task_id='branch_b', dag=dag)
        current_task = PythonOperatorFactory(task_id='current_task', dag=dag)
        next_task = PythonOperatorFactory(task_id='next_task', dag=dag)

        # Diamond: both branches join into current_task, then next_task.
        branch_a.set_downstream(current_task)
        branch_b.set_downstream(current_task)
        current_task.set_downstream(next_task)

        dag_run = dag.create_dagrun(run_id="manual__",
                                    start_date=timezone.utcnow(),
                                    execution_date=timezone.utcnow(),
                                    state=State.RUNNING,
                                    conf=None,
                                    session=session)

        ti = next((instance for instance in dag_run.get_task_instances()
                   if instance.task_id == 'current_task'), None)
        if ti is None:
            raise ValueError('Unable to find the current task')
        ti.task = current_task
        return ti.get_template_context()
def get_job(self, job_name: Text, execution_id: Text) -> Optional[JobInfo]:
    """
    Look up the task *job_name* within the dag run *execution_id*.

    :return: a JobInfo describing the task and its workflow execution,
        or None when the run or task cannot be found.
    """
    with create_session() as session:
        dag_run = session.query(DagRun).filter(
            DagRun.run_id == execution_id).first()
        if dag_run is None:
            return None

        task = session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag_run.dag_id,
            TaskInstance.execution_date == dag_run.execution_date,
            TaskInstance.task_id == job_name).first()
        if task is None:
            return None

        project_name, workflow_name = self.dag_id_to_namespace_workflow(
            dag_run.dag_id)
        workflow_execution = WorkflowExecutionInfo(
            workflow_info=WorkflowInfo(namespace=project_name,
                                       workflow_name=workflow_name),
            execution_id=dag_run.run_id,
            state=self.airflow_state_to_state(dag_run.state))
        return JobInfo(job_name=job_name,
                       state=self.airflow_state_to_state(task.state),
                       workflow_execution=workflow_execution)
def wrapper(*args, **kwargs):
    """Record a Log row for this request, then run the wrapped view."""
    # AnonymousUserMixin() has user attribute but its value is None.
    if current_user and hasattr(current_user, 'user') and current_user.user:
        user = current_user.user.username
    else:
        user = '******'

    log = models.Log(event=f.__name__,
                     task_instance=None,
                     owner=user,
                     extra=str(list(request.values.items())),
                     task_id=request.values.get('task_id'),
                     dag_id=request.values.get('dag_id'))
    if request.values.get('execution_date'):
        log.execution_date = timezone.parse(
            request.values.get('execution_date'))

    with create_session() as session:
        session.add(log)
        session.commit()

    return f(*args, **kwargs)
def setUp(self):
    """Insert one row per model so delete_dag has data to operate on."""
    self.key = "test_dag_id"
    dag = models.DAG(dag_id=self.key,
                     default_args={'start_date': days_ago(2)})
    task = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
    test_date = days_ago(1)
    with create_session() as session:
        session.add(DM(dag_id=self.key))
        session.add(DR(dag_id=self.key))
        session.add(TI(task=task,
                       execution_date=test_date,
                       state=State.SUCCESS))
        # flush to ensure task instance is written before the
        # task reschedule row because of the FK constraint
        session.flush()
        session.add(LOG(dag_id=self.key, task_id=None, task_instance=None,
                        execution_date=test_date, event="varimport"))
        session.add(TF(task=task, execution_date=test_date,
                       start_date=test_date, end_date=test_date))
        session.add(TR(task=task, execution_date=test_date,
                       start_date=test_date, end_date=test_date,
                       try_number=1, reschedule_date=test_date))
def test_kill_zombies_doesn_nothing(self, mock_ti_handle_failure):
    """
    kill_zombies does nothing when the owning job is RUNNING and has a
    recent heartbeat.
    """
    dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
    with create_session() as session:
        session.query(TI).delete()
        session.query(LJ).delete()
        task = dagbag.get_dag('example_branch_operator') \
            .get_task(task_id='run_this_first')

        ti = TI(task, DEFAULT_DATE, State.RUNNING)
        # Healthy job: RUNNING state plus a fresh heartbeat.
        lj = LJ(ti)
        lj.latest_heartbeat = utcnow()
        lj.state = State.RUNNING
        lj.id = 1
        ti.job_id = lj.id

        session.add(lj)
        session.add(ti)
        session.commit()

        dagbag.kill_zombies()
        mock_ti_handle_failure.assert_not_called()
def tearDown(self):
    super(ShortCircuitOperatorTest, self).tearDown()
    # Purge runs and task instances created by the tests.
    with create_session() as session:
        for model in (DagRun, TI):
            session.query(model).delete()
def tearDown(self):
    """Drop all DagRun and TaskInstance rows after each test."""
    super().tearDown()
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TI).delete()
def setUpClass(cls):
    """Start the suite from a clean slate: no runs, no task instances."""
    super().setUpClass()
    with create_session() as session:
        for model in (DagRun, TI):
            session.query(model).delete()
def clear_rendered_ti_fields():
    """Delete every RenderedTaskInstanceFields row."""
    with create_session() as session:
        session.query(RenderedTaskInstanceFields).delete()
def main(num_runs, repeat, pre_create_dag_runs, executor_class, dag_ids):
    """
    This script can be used to measure the total "scheduler overhead" of
    Airflow.

    By overhead we mean if the tasks executed instantly as soon as they are
    executed (i.e. they do nothing) how quickly could we schedule them.

    It will monitor the task completion of the Mock/stub executor (no actual
    tasks are run) and after the required number of dag runs for all the
    specified dags have completed all their tasks, it will cleanly shut down
    the scheduler.

    The dags you run with need to have an early enough start_date to create
    the desired number of runs.

    Care should be taken that other limits (DAG concurrency, pool size etc)
    are not the bottleneck. This script doesn't help you in that regard.

    It is recommended to repeat the test at least 3 times (`--repeat=3`, the
    default) so that you can get somewhat-accurate variance on the reported
    timing numbers, but this can be disabled for longer runs if needed.
    """
    # Turn on unit test mode so that we don't do any sleep() in the scheduler
    # loop - not needed on master, but this script can run against older
    # releases too!
    os.environ['AIRFLOW__CORE__UNIT_TEST_MODE'] = 'True'

    os.environ['AIRFLOW__CORE__DAG_CONCURRENCY'] = '500'

    # Set this so that dags can dynamically configure their end_date
    os.environ['AIRFLOW_BENCHMARK_MAX_DAG_RUNS'] = str(num_runs)
    os.environ['PERF_MAX_RUNS'] = str(num_runs)

    if pre_create_dag_runs:
        os.environ['AIRFLOW__SCHEDULER__USE_JOB_SCHEDULE'] = 'False'

    # Imported late so the env vars above take effect on Airflow config.
    from airflow.jobs.scheduler_job import SchedulerJob
    from airflow.models.dagbag import DagBag
    from airflow.utils import db

    dagbag = DagBag()

    dags = []

    with db.create_session() as session:
        pause_all_dags(session)
        for dag_id in dag_ids:
            dag = dagbag.get_dag(dag_id)
            dag.sync_to_db(session=session)
            dags.append(dag)
            reset_dag(dag, session)

            # Sanity check: the dag's end_date must produce exactly
            # `num_runs` schedule intervals from its start.
            next_run_date = dag.normalize_schedule(dag.start_date or min(t.start_date for t in dag.tasks))

            for _ in range(num_runs - 1):
                next_run_date = dag.following_schedule(next_run_date)

            end_date = dag.end_date or dag.default_args.get('end_date')
            if end_date != next_run_date:
                message = (
                    f"DAG {dag_id} has incorrect end_date ({end_date}) for number of runs! "
                    f"It should be "
                    f" {next_run_date}")
                sys.exit(message)

            if pre_create_dag_runs:
                create_dag_runs(dag, num_runs, session)

    ShortCircuitExecutor = get_executor_under_test(executor_class)

    executor = ShortCircuitExecutor(dag_ids_to_watch=dag_ids, num_runs=num_runs)
    scheduler_job = SchedulerJob(dag_ids=dag_ids, do_pickle=False, executor=executor)
    executor.scheduler_job = scheduler_job

    total_tasks = sum(len(dag.tasks) for dag in dags)

    # Optional py-spy flame-graph capture of this process.
    if 'PYSPY' in os.environ:
        pid = str(os.getpid())
        filename = os.environ.get('PYSPY_O', 'flame-' + pid + '.html')
        os.spawnlp(os.P_NOWAIT, 'sudo', 'sudo', 'py-spy', 'record', '-o', filename, '-p', pid, '--idle')

    times = []

    # Need a lambda to refer to the _latest_ value for scheduler_job, not just
    # the initial one
    code_to_test = lambda: scheduler_job.run()  # pylint: disable=unnecessary-lambda

    for count in range(repeat):
        # GC disabled during the timed region to reduce noise.
        gc.disable()
        start = time.perf_counter()

        code_to_test()
        times.append(time.perf_counter() - start)
        gc.enable()
        print("Run %d time: %.5f" % (count + 1, times[-1]))

        if count + 1 != repeat:
            # Reset state and build a fresh SchedulerJob for the next round.
            with db.create_session() as session:
                for dag in dags:
                    reset_dag(dag, session)

                executor.reset(dag_ids)
                scheduler_job = SchedulerJob(dag_ids=dag_ids, do_pickle=False, executor=executor)
                executor.scheduler_job = scheduler_job

    print()
    print()
    msg = "Time for %d dag runs of %d dags with %d total tasks: %.4fs"

    if len(times) > 1:
        print((msg + " (±%.3fs)") % (num_runs, len(dags), total_tasks, statistics.mean(times), statistics.stdev(times)))
    else:
        print(msg % (num_runs, len(dags), total_tasks, times[0]))

    print()
    print()
def test_create_pool(self):
    """The client returns the created pool tuple and persists one row."""
    pool = self.client.create_pool(name='foo', slots=1, description='')
    self.assertEqual(('foo', 1, ''), pool)
    with create_session() as session:
        self.assertEqual(1, session.query(models.Pool).count())
def clear_db_dag_code():
    """Delete every DagCode row."""
    with create_session() as session:
        session.query(DagCode).delete()
def lineage_parent_id(run_id, task):
    """Build the OpenLineage parent id 'namespace/job/run' for *task*."""
    with create_session() as session:
        job_name = f"{task.dag_id}.{task.task_id}"
        ids = str(JobIdMapping.get(job_name, run_id, session))
        namespace = os.getenv('OPENLINEAGE_NAMESPACE')
        return f"{namespace}/{job_name}/{ids}"
def clear_db_task_instance():
    """Delete all task instances plus their execution/state records."""
    with create_session() as session:
        for model in (TaskInstance, TaskExecution, TaskState):
            session.query(model).delete()
def clear_db_dag_pickle():
    """Delete every DagPickle row."""
    with create_session() as session:
        session.query(DagPickle).delete()
def clear_db_event_model():
    """Delete every EventModel row."""
    with create_session() as session:
        session.query(EventModel).delete()
def setUpClass(cls):
    """Wipe DagRun and TaskInstance tables before the suite runs."""
    super(PythonOperatorTest, cls).setUpClass()
    with create_session() as session:
        for model in (DagRun, TI):
            session.query(model).delete()
def clear_db_variables():
    """Delete every Variable row."""
    with create_session() as session:
        session.query(Variable).delete()
def setUp(self):
    """Reset pools and create a single 'test_pool' with one slot."""
    db.clear_db_pools()
    with create_session() as session:
        session.add(Pool(pool='test_pool', slots=1))
        session.commit()
from airflow import DAG
from airflow.models import Variable
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dates import days_ago
from airflow.utils.db import create_session

# Deliberate top-level DB query at DAG parse time. Judging by the dag_id
# ("..._fail"), this file presumably serves as a fixture that a lint/check
# is expected to flag — do not "fix" the query; confirm before changing.
with create_session() as session:
    session.query(Variable).all()

with DAG(dag_id="dag_no_top_level_query_fail",
         schedule_interval=None,
         start_date=days_ago(1)) as dag:
    DummyOperator(task_id="test")
def snapshot_state(dag, execution_dates):
    """Return all TaskInstances of *dag* at the given execution dates."""
    TI = models.TaskInstance
    with create_session() as session:
        return (session.query(TI)
                .filter(TI.dag_id == dag.dag_id,
                        TI.execution_date.in_(execution_dates))
                .all())
def clear_db_connections():
    """Wipe all Connection rows, then restore the default connections."""
    with create_session() as session:
        session.query(Connection).delete()
        create_default_connections(session)
def tearDown(self):
    """Remove task-level artifacts created during the test."""
    with create_session() as session:
        for model in (TaskFail, TaskReschedule, models.TaskInstance):
            session.query(model).delete()
def test_backfill_max_limit_check(self):
    # Purpose: a backfill blocked by max_active_runs waits until the
    # pre-existing dag run finishes, then completes its own runs.
    dag_id = 'test_backfill_max_limit_check'
    run_id = 'test_dagrun'
    start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
    end_date = DEFAULT_DATE

    # Condition used to hand off between the main thread and the
    # backfill thread once the blocking dag run has been created.
    dag_run_created_cond = threading.Condition()

    def run_backfill(cond):
        cond.acquire()
        # this session object is different than the one in the main thread
        with create_session() as thread_session:
            try:
                dag = self._get_dag_test_max_active_limits(dag_id)

                # Existing dagrun that is not within the backfill range
                dag.create_dagrun(
                    run_id=run_id,
                    state=State.RUNNING,
                    execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
                    start_date=DEFAULT_DATE,
                )

                thread_session.commit()
                cond.notify()
            finally:
                cond.release()
                thread_session.close()

            executor = MockExecutor()
            job = BackfillJob(dag=dag,
                              start_date=start_date,
                              end_date=end_date,
                              executor=executor,
                              donot_pickle=True)
            job.run()

    backfill_job_thread = threading.Thread(target=run_backfill,
                                           name="run_backfill",
                                           args=(dag_run_created_cond, ))

    dag_run_created_cond.acquire()
    with create_session() as session:
        backfill_job_thread.start()
        try:
            # at this point backfill can't run since the max_active_runs
            # has been reached, so it is waiting
            dag_run_created_cond.wait(timeout=1.5)
            dagruns = DagRun.find(dag_id=dag_id)
            dr = dagruns[0]
            self.assertEqual(1, len(dagruns))
            self.assertEqual(dr.run_id, run_id)

            # allow the backfill to execute
            # by setting the existing dag run to SUCCESS,
            # backfill will execute dag runs 1 by 1
            dr.set_state(State.SUCCESS)
            session.merge(dr)
            session.commit()

            backfill_job_thread.join()

            dagruns = DagRun.find(dag_id=dag_id)
            self.assertEqual(3, len(dagruns))  # 2 from backfill + 1 existing
            self.assertEqual(dagruns[-1].run_id, dr.run_id)
        finally:
            dag_run_created_cond.release()
def set_default_pool_slots(slots):
    """Set the slot count of the default pool to *slots*."""
    with create_session() as session:
        Pool.get_default_pool(session).slots = slots