def wrapper(*args, **kwargs):
    with create_session() as session:
        if g.user.is_anonymous:
            user = '******'
        else:
            user = g.user.username

        fields_skip_logging = {'csrf_token', '_csrf_token'}
        log = Log(
            event=f.__name__,
            task_instance=None,
            owner=user,
            extra=str([(k, v) for k, v in request.values.items() if k not in fields_skip_logging]),
            task_id=request.values.get('task_id'),
            dag_id=request.values.get('dag_id'),
        )

        if 'execution_date' in request.values:
            log.execution_date = pendulum.parse(request.values.get('execution_date'), strict=False)

        session.add(log)

    return f(*args, **kwargs)
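# A minimal sketch (not taken from the snippet above) of how this wrapper could be
# assembled into a logging decorator and applied: the wrapper closes over an outer
# function ``f``, which indicates it is the inner function of such a decorator.
# The names ``log_action`` and ``trigger_dag_view`` are hypothetical and only
# illustrate the wrapping pattern.
import functools

def log_action(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        ...  # audit-log the request (as in the snippet above), then delegate
        return f(*args, **kwargs)
    return wrapper

@log_action
def trigger_dag_view():
    ...  # hypothetical Flask view whose invocations would be audit-logged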
def test_sync_perm_for_dag(self, mock_security_manager):
    """
    Test that dagbag._sync_perm_for_dag will call
    ApplessAirflowSecurityManager.sync_perm_for_dag when DAG-specific perm views
    don't exist already or the DAG has access_control set.
    """
    delete_dag_specific_permissions()
    with create_session() as session:
        security_manager = ApplessAirflowSecurityManager(session)
        mock_sync_perm_for_dag = mock_security_manager.return_value.sync_perm_for_dag
        mock_sync_perm_for_dag.side_effect = security_manager.sync_perm_for_dag

        dagbag = DagBag(
            dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"),
            include_examples=False,
        )
        dag = dagbag.dags["test_example_bash_operator"]

        def _sync_perms():
            mock_sync_perm_for_dag.reset_mock()
            dagbag._sync_perm_for_dag(dag, session=session)

        # Perm views don't exist yet
        _sync_perms()
        mock_sync_perm_for_dag.assert_called_once_with("test_example_bash_operator", None)

        # Perm views now exist
        _sync_perms()
        mock_sync_perm_for_dag.assert_not_called()

        # Always sync if we have access_control
        dag.access_control = {"Public": {"can_read"}}
        _sync_perms()
        mock_sync_perm_for_dag.assert_called_once_with(
            "test_example_bash_operator", {"Public": {"can_read"}}
        )
def test_cli_connection_add(self, cmd, expected_output, expected_conn):
    with redirect_stdout(io.StringIO()) as stdout:
        connection_command.connections_add(self.parser.parse_args(cmd))
        stdout = stdout.getvalue()

    self.assertIn(expected_output, stdout)
    conn_id = cmd[2]
    with create_session() as session:
        comparable_attrs = [
            "conn_type",
            "host",
            "is_encrypted",
            "is_extra_encrypted",
            "login",
            "port",
            "schema",
        ]
        current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
        self.assertEqual(
            expected_conn, {attr: getattr(current_conn, attr) for attr in comparable_attrs}
        )
def test_parent_follow_branch():
    """
    A simple DAG with a BranchPythonOperator that follows op2.
    NotPreviouslySkippedDep is met.
    """
    start_date = pendulum.datetime(2020, 1, 1)
    dag = DAG("test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date)
    dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date)
    op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag)
    op2 = DummyOperator(task_id="op2", dag=dag)
    op1 >> op2

    TaskInstance(op1, start_date).run()
    ti2 = TaskInstance(op2, start_date)

    with create_session() as session:
        dep = NotPreviouslySkippedDep()
        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0
        assert dep.is_met(ti2, session)
        assert ti2.state != State.SKIPPED
def downgrade():
    """Make TaskInstance.pool field nullable."""
    conn = op.get_bind()
    if conn.dialect.name == "mssql":
        op.drop_index('ti_pool', table_name='task_instance')

    # use batch_alter_table to support the SQLite workaround
    with op.batch_alter_table('task_instance') as batch_op:
        batch_op.alter_column(
            column_name='pool',
            type_=sa.String(50),
            nullable=True,
        )

    if conn.dialect.name == "mssql":
        op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'])

    with create_session() as session:
        session.query(TaskInstance).filter(TaskInstance.pool == 'default_pool').update(
            {TaskInstance.pool: None},
            synchronize_session=False,  # avoid selecting the rows we just updated
        )
        session.commit()
def test_extra_link_in_gantt_view(dag, viewer_client):
    exec_date = dates.days_ago(2)
    start_date = timezone.datetime(2020, 4, 10, 2, 0, 0)
    end_date = exec_date + datetime.timedelta(seconds=30)

    with create_session() as session:
        for task in dag.tasks:
            ti = TaskInstance(task=task, execution_date=exec_date, state="success")
            ti.start_date = start_date
            ti.end_date = end_date
            session.add(ti)

    url = f'gantt?dag_id={dag.dag_id}&execution_date={exec_date}'
    resp = viewer_client.get(url, follow_redirects=True)

    check_content_in_response('"extraLinks":', resp)

    extra_links_grps = re.search(r'extraLinks\": \[(\".*?\")\]', resp.get_data(as_text=True))
    extra_links = extra_links_grps.group(0)
    assert 'airflow' in extra_links
    assert 'github' in extra_links
def heartbeat(self):
    if not self.do_update:
        return

    with create_session() as session:
        self.history.append(list(self.queued_tasks.values()))

        # Create a stable/predictable sort order for events in self.history
        # for tests!
        def sort_by(item):
            key, val = item
            (dag_id, task_id, date, try_number) = key
            (_, prio, _, _) = val
            # Sort by priority (DESC), then date, dag_id, task_id, try_number
            return -prio, date, dag_id, task_id, try_number

        open_slots = self.parallelism - len(self.running)
        sorted_queue = sorted(self.queued_tasks.items(), key=sort_by)
        for index in range(min(open_slots, len(sorted_queue))):
            (key, (_, _, _, ti)) = sorted_queue[index]
            self.queued_tasks.pop(key)
            state = self.mock_task_results[key]
            ti.set_state(state, session=session)
            self.change_state(key, state)
def run_task_function(client: NotificationClient):
    with af.global_config_file(workflow_config_file()):
        with af.config('task_1'):
            cmd_executor = af.user_define_operation(
                output_num=0,
                executor=CmdExecutor(cmd_line='echo "hello world" && sleep 30'),
            )
        workflow_info = af.workflow_operation.submit_workflow(workflow_name)

        we = af.workflow_operation.start_new_workflow_execution(workflow_name)
        self.assertEqual(project_name, we.workflow_info.namespace)
        self.assertEqual(workflow_name, we.workflow_info.workflow_name)

        we_2 = af.workflow_operation.get_workflow_execution(we.execution_id)
        self.assertEqual(we.execution_id, we_2.execution_id)
        self.assertEqual(project_name, we_2.workflow_info.namespace)
        self.assertEqual(workflow_name, we_2.workflow_info.workflow_name)

        we_list = af.workflow_operation.list_workflow_executions(workflow_name)
        self.assertEqual(1, len(we_list))

        while True:
            with create_session() as session:
                ti = session.query(TaskInstance) \
                    .filter(TaskInstance.dag_id == dag_id).first()
                if ti is not None and ti.state == State.RUNNING:
                    af.workflow_operation.kill_workflow_execution(we.execution_id)
                elif ti is not None and ti.state == State.KILLED:
                    break
                else:
                    time.sleep(1)
def test_next_execution(self):
    dag_ids = [
        'example_bash_operator',  # schedule_interval is '0 0 * * *'
        'latest_only',  # schedule_interval is timedelta(hours=4)
        'example_python_operator',  # schedule_interval=None
        'example_xcom',  # schedule_interval="@once"
    ]

    # Delete DagRuns
    with create_session() as session:
        dr = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids))
        dr.delete(synchronize_session=False)

    # Test None output
    args = self.parser.parse_args(['dags', 'next-execution', dag_ids[0]])
    with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
        dag_command.dag_next_execution(args)
        out = temp_stdout.getvalue()
    # `next_execution` is inapplicable if no execution record is found;
    # it prints `None` in such cases
    self.assertIn("None", out)

    # The expected values below are determined by the schedule_interval of the example DAGs
    now = DEFAULT_DATE
    expected_output = [
        str(now + timedelta(days=1)),
        str(now + timedelta(hours=4)),
        "None",
        "None",
    ]
    expected_output_2 = [
        str(now + timedelta(days=1)) + os.linesep + str(now + timedelta(days=2)),
        str(now + timedelta(hours=4)) + os.linesep + str(now + timedelta(hours=8)),
        "None",
        "None",
    ]

    for i, dag_id in enumerate(dag_ids):
        dag = self.dagbag.dags[dag_id]
        # Create a DagRun for each DAG, to prepare for the next step
        dag.create_dagrun(
            run_type=DagRunType.MANUAL,
            execution_date=now,
            start_date=now,
            state=State.FAILED,
        )

        # Test num-executions = 1 (default)
        args = self.parser.parse_args(['dags', 'next-execution', dag_id])
        with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
            dag_command.dag_next_execution(args)
            out = temp_stdout.getvalue()
        self.assertIn(expected_output[i], out)

        # Test num-executions = 2
        args = self.parser.parse_args(['dags', 'next-execution', dag_id, '--num-executions', '2'])
        with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
            dag_command.dag_next_execution(args)
            out = temp_stdout.getvalue()
        self.assertIn(expected_output_2[i], out)

    # Clean up before leaving
    with create_session() as session:
        dr = session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids))
        dr.delete(synchronize_session=False)
def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor(self):
    """
    Check that the same set of failure callbacks with zombies is passed to the dag
    file processors until the next zombie detection logic is invoked.
    """
    test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
    with conf_vars({('scheduler', 'parsing_processes'): '1', ('core', 'load_examples'): 'False'}):
        dagbag = DagBag(test_dag_path, read_dags_from_db=False)
        with create_session() as session:
            session.query(LJ).delete()
            dag = dagbag.get_dag('test_example_bash_operator')
            dag.sync_to_db()
            task = dag.get_task(task_id='run_this_last')

            ti = TI(task, DEFAULT_DATE, State.RUNNING)
            local_job = LJ(ti)
            local_job.state = State.SHUTDOWN

            session.add(local_job)
            session.commit()

            # TODO: If there was an actual Relationship between TI and Job
            # we wouldn't need this extra commit
            session.add(ti)
            ti.job_id = local_job.id
            session.commit()

            expected_failure_callback_requests = [
                TaskCallbackRequest(
                    full_filepath=dag.full_filepath,
                    simple_task_instance=SimpleTaskInstance(ti),
                    msg="Message",
                )
            ]

        test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')

        child_pipe, parent_pipe = multiprocessing.Pipe()
        async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')

        fake_processors = []

        def fake_processor_factory(*args, **kwargs):
            nonlocal fake_processors
            processor = FakeDagFileProcessorRunner._fake_dag_processor_factory(*args, **kwargs)
            fake_processors.append(processor)
            return processor

        manager = DagFileProcessorManager(
            dag_directory=test_dag_path,
            max_runs=1,
            processor_factory=fake_processor_factory,
            processor_timeout=timedelta.max,
            signal_conn=child_pipe,
            dag_ids=[],
            pickle_dags=False,
            async_mode=async_mode,
        )

        self.run_processor_manager_one_loop(manager, parent_pipe)

        if async_mode:
            # Once for the initial parse, and then again for the add_callback_to_queue
            assert len(fake_processors) == 2
            assert fake_processors[0]._file_path == test_dag_path
            assert fake_processors[0]._callback_requests == []
        else:
            assert len(fake_processors) == 1

        assert fake_processors[-1]._file_path == test_dag_path
        callback_requests = fake_processors[-1]._callback_requests
        assert {
            zombie.simple_task_instance.key for zombie in expected_failure_callback_requests
        } == {result.simple_task_instance.key for result in callback_requests}

        child_pipe.close()
        parent_pipe.close()
def test_delete_pool(self):
    pool = pool_api.delete_pool(name=self.pools[-1].pool)
    assert pool.pool == self.pools[-1].pool
    with create_session() as session:
        assert session.query(models.Pool).count() == self.TOTAL_POOL_COUNT - 1
def clear_db_import_errors():
    with create_session() as session:
        session.query(errors.ImportError).delete()
def set_default_pool_slots(slots):
    with create_session() as session:
        default_pool = Pool.get_default_pool(session)
        default_pool.slots = slots
def clear_db_variables():
    with create_session() as session:
        session.query(Variable).delete()
def test_create_pool(self):
    pool = self.client.create_pool(name='foo', slots=1, description='')
    self.assertEqual(pool, ('foo', 1, ''))
    with create_session() as session:
        self.assertEqual(session.query(Pool).count(), 2)
def clear_db_pools():
    with create_session() as session:
        session.query(Pool).delete()
        add_default_pool_if_not_exists(session)
def clear_db_connections():
    with create_session() as session:
        session.query(Connection).delete()
        create_default_connections(session)
def test_backfill_max_limit_check(self):
    dag_id = 'test_backfill_max_limit_check'
    run_id = 'test_dagrun'
    start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
    end_date = DEFAULT_DATE

    dag_run_created_cond = threading.Condition()

    def run_backfill(cond):
        cond.acquire()
        # this session object is different from the one in the main thread
        with create_session() as thread_session:
            try:
                dag = self._get_dag_test_max_active_limits(dag_id)

                # Existing dagrun that is not within the backfill range
                dag.create_dagrun(
                    run_id=run_id,
                    state=State.RUNNING,
                    execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
                    start_date=DEFAULT_DATE,
                )
                thread_session.commit()
                cond.notify()
            finally:
                cond.release()
                thread_session.close()

        executor = MockExecutor()
        job = BackfillJob(
            dag=dag,
            start_date=start_date,
            end_date=end_date,
            executor=executor,
            donot_pickle=True,
        )
        job.run()

    backfill_job_thread = threading.Thread(
        target=run_backfill, name="run_backfill", args=(dag_run_created_cond,)
    )

    dag_run_created_cond.acquire()
    with create_session() as session:
        backfill_job_thread.start()
        try:
            # at this point the backfill can't run since max_active_runs has been
            # reached, so it is waiting
            dag_run_created_cond.wait(timeout=1.5)
            dagruns = DagRun.find(dag_id=dag_id)
            dr = dagruns[0]
            self.assertEqual(1, len(dagruns))
            self.assertEqual(dr.run_id, run_id)

            # allow the backfill to execute by setting the existing dag run to
            # SUCCESS; backfill will then execute dag runs 1 by 1
            dr.set_state(State.SUCCESS)
            session.merge(dr)
            session.commit()

            backfill_job_thread.join()

            dagruns = DagRun.find(dag_id=dag_id)
            self.assertEqual(3, len(dagruns))  # 2 from backfill + 1 existing
            self.assertEqual(dagruns[-1].run_id, dr.run_id)
        finally:
            dag_run_created_cond.release()
def clear_db_dag_code():
    with create_session() as session:
        session.query(DagCode).delete()
def setUp(self) -> None:
    self.client = self.app.test_client()  # type:ignore
    # we want only the connection created here for this test
    with create_session() as session:
        session.query(Connection).delete()
def clear_rendered_ti_fields():
    with create_session() as session:
        session.query(RenderedTaskInstanceFields).delete()
def setUp(self) -> None:
    with create_session() as session:
        session.query(Log).delete()
    self.default_time = "2020-06-09T13:00:00+00:00"
    self.default_time2 = '2020-06-11T07:00:00+00:00'
def heartbeat(self, only_if_necessary: bool = False):
    """
    Heartbeats update the job's entry in the database with a timestamp for the
    latest_heartbeat and allow the job to be killed externally. This makes it
    possible to monitor at the system level which jobs are actually active.
    For instance, an old heartbeat for SchedulerJob would mean something is
    wrong. This also allows any job to be killed externally, regardless of who
    is running it or on which machine it is running.

    Note that if your heart rate is set to 60 seconds and you call this method
    after 10 seconds of processing since the last heartbeat, it will sleep 50
    seconds to complete the 60 seconds and keep a steady heart rate. If you go
    over 60 seconds before calling it, it won't sleep at all.

    :param only_if_necessary: If the heartbeat is not yet due then do
        nothing (don't update column, don't call ``heartbeat_callback``)
    :type only_if_necessary: boolean
    """
    seconds_remaining = 0
    if self.latest_heartbeat:
        seconds_remaining = self.heartrate - (timezone.utcnow() - self.latest_heartbeat).total_seconds()

    if seconds_remaining > 0 and only_if_necessary:
        return

    previous_heartbeat = self.latest_heartbeat

    try:
        with create_session() as session:
            # This will cause it to load from the db
            session.merge(self)
            previous_heartbeat = self.latest_heartbeat

        if self.state in State.terminating_states:
            self.kill()

        # Figure out how long to sleep for
        sleep_for = 0
        if self.latest_heartbeat:
            seconds_remaining = (
                self.heartrate - (timezone.utcnow() - self.latest_heartbeat).total_seconds()
            )
            sleep_for = max(0, seconds_remaining)
        sleep(sleep_for)

        # Update last heartbeat time
        with create_session() as session:
            # Make the session aware of this object
            session.merge(self)
            self.latest_heartbeat = timezone.utcnow()
            session.commit()
            # At this point, the DB has updated.
            previous_heartbeat = self.latest_heartbeat

            self.heartbeat_callback(session=session)
            self.log.debug('[heartbeat]')
    except OperationalError:
        Stats.incr(convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1, 1)
        self.log.exception("%s heartbeat got an exception", self.__class__.__name__)
        # We didn't manage to heartbeat, so make sure that the timestamp isn't updated
        self.latest_heartbeat = previous_heartbeat
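# A standalone illustration of the sleep calculation described in the docstring
# above, using the docstring's hypothetical numbers (60-second heart rate,
# 10 seconds of processing since the last heartbeat); not data from a running job.
heartrate = 60
seconds_since_last_heartbeat = 10

seconds_remaining = heartrate - seconds_since_last_heartbeat
sleep_for = max(0, seconds_remaining)
assert sleep_for == 50  # sleep 50s to keep a steady 60-second rhythm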
def tearDown(self) -> None:
    with create_session() as session:
        session.query(Log).delete()
def clean_up(self):
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TaskInstance).delete()
def clear_db_runs():
    with create_session() as session:
        session.query(DagRun).delete()
        session.query(TaskInstance).delete()
def reset(dag_id):
    with create_session() as session:
        tis = session.query(TaskInstance).filter_by(dag_id=dag_id)
        tis.delete()
        runs = session.query(DagRun).filter_by(dag_id=dag_id)
        runs.delete()
def clear_db_dags():
    with create_session() as session:
        session.query(DagTag).delete()
        session.query(DagModel).delete()
def clear_db_serialized_dags():
    with create_session() as session:
        session.query(SDM).delete()
def clear_db_sla_miss():
    with create_session() as session:
        session.query(SlaMiss).delete()
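# A minimal sketch (not from the snippets above) of how the clear_db_* helpers are
# commonly combined into a fixture that resets database state around each test.
# It assumes pytest is available and that the helpers are importable from a shared
# test-utilities module; the fixture name is hypothetical.
import pytest

@pytest.fixture(autouse=True)
def _reset_airflow_db():
    clear_db_runs()
    clear_db_dags()
    clear_db_pools()
    clear_db_connections()
    yield
    clear_db_runs()
    clear_db_dags()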