def setUpClass(cls):
    """Start the API test class with empty DagRun and TaskInstance tables."""
    super(TestApiExperimental, cls).setUpClass()
    db = Session()
    for model in (DagRun, TaskInstance):
        db.query(model).delete()
    db.commit()
    db.close()
class TestChartModelView(unittest.TestCase):
    """Web tests for the chart admin model view (create form and list page)."""

    CREATE_ENDPOINT = '/admin/chart/new/?url=/admin/chart/'

    @classmethod
    def setUpClass(cls):
        """Clear Chart and User tables, then create a single test user."""
        super(TestChartModelView, cls).setUpClass()
        session = Session()
        session.query(models.Chart).delete()
        session.query(models.User).delete()
        session.commit()
        user = models.User(username='******')
        session.add(user)
        session.commit()
        session.close()

    def setUp(self):
        """Build a test client and the form payload used to create a chart."""
        super(TestChartModelView, self).setUp()
        configuration.load_test_config()
        app = application.create_app(testing=True)
        # Empty list: no HTTP method is subject to Flask-WTF CSRF checks,
        # so forms can be posted directly from the test client.
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        self.chart = {
            'label': 'chart',
            'owner': 'airflow',
            'conn_id': 'airflow_ci',
        }

    def tearDown(self):
        """Delete charts created by the test and close the session."""
        self.session.query(models.Chart).delete()
        self.session.commit()
        self.session.close()
        super(TestChartModelView, self).tearDown()

    @classmethod
    def tearDownClass(cls):
        """Remove the users created for this test class."""
        session = Session()
        session.query(models.User).delete()
        session.commit()
        session.close()
        super(TestChartModelView, cls).tearDownClass()

    def test_create_chart(self):
        response = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.chart,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(self.session.query(models.Chart).count(), 1)

    def test_get_chart(self):
        response = self.app.get(
            '/admin/chart?sort=3',
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn('Sort by Owner', response.data.decode('utf-8'))
class TestTriggerDag(unittest.TestCase):
    """Web tests for the DAG trigger button and the trigger endpoint."""

    def setUp(self):
        conf.load_test_config()
        app = application.create_app(testing=True)
        app.config['WTF_CSRF_METHODS'] = []
        self.app = app.test_client()
        self.session = Session()
        # Register the example DAG in the metadata DB before each test.
        models.DagBag().get_dag("example_bash_operator").sync_to_db()

    def tearDown(self):
        """Close the session opened in setUp."""
        self.session.close()
        super(TestTriggerDag, self).tearDown()

    def test_trigger_dag_button_normal_exist(self):
        resp = self.app.get('/', follow_redirects=True)
        html = resp.data.decode('utf-8')
        self.assertIn('/trigger?dag_id=example_bash_operator', html)
        self.assertIn("return confirmDeleteDag('example_bash_operator')", html)

    def test_trigger_dag_button(self):
        test_dag_id = "example_bash_operator"

        DR = models.DagRun
        self.session.query(DR).delete()
        self.session.commit()

        self.app.get('/admin/airflow/trigger?dag_id={}'.format(test_dag_id))

        run = self.session.query(DR).filter(DR.dag_id == test_dag_id).first()
        self.assertIsNotNone(run)
        self.assertIn("manual__", run.run_id)
def tearDown(self):
    """Return to test mode, drop every User row and disable webserver auth."""
    configuration.test_mode()
    db = Session()
    user_rows = db.query(models.User)
    user_rows.delete()
    db.commit()
    db.close()
    configuration.conf.set("webserver", "authenticate", "False")
def clear_session():
    """Reset airflow database state for tests.

    Deletes every DagRun row, then every task instance row, and commits.
    """
    db = Session()
    for model in (DagRun, TI):
        db.query(model).delete()
    db.commit()
    db.close()
def tearDown(self):
    """Empty the DagRun and TaskInstance tables, then run the parent tearDown."""
    db = Session()
    for model in (DagRun, TaskInstance):
        db.query(model).delete()
    db.commit()
    db.close()
    super(TestApiExperimental, self).tearDown()
def tearDown(self):
    """Delete the TaskInstance and TaskFail rows belonging to TEST_DAG_ID."""
    db = Session()
    for model in (models.TaskInstance, TaskFail):
        db.query(model).filter_by(dag_id=TEST_DAG_ID).delete()
    db.commit()
    db.close()
def setUpClass(cls):
    """Delete any TaskInstance left over for this class's dag/task/date.

    The three conditions are passed to ``filter`` as separate arguments so
    SQLAlchemy ANDs them in SQL. Joining them with Python ``and`` evaluates
    the clauses' Python truthiness instead. An ``==`` clause against a bound
    value is falsy, so only the first clause (``dag_id``) would survive.
    """
    super(TestLogView, cls).setUpClass()
    session = Session()
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == cls.DAG_ID,
        TaskInstance.task_id == cls.TASK_ID,
        TaskInstance.execution_date == cls.DEFAULT_DATE).delete()
    session.commit()
    session.close()
def tearDown(self):
    """Run the parent tearDown, then wipe all DagRun and task instance rows."""
    super(ShortCircuitOperatorTest, self).tearDown()
    db = Session()
    for model in (DagRun, TI):
        db.query(model).delete()
    db.commit()
    db.close()
def setUpClass(cls):
    """Replace every existing User with a single test user."""
    super(TestVarImportView, cls).setUpClass()
    db = Session()
    db.query(models.User).delete()
    db.commit()
    test_user = models.User(username='******')
    db.add(test_user)
    db.commit()
    db.close()
def setUpClass(cls):
    """Start the operator tests with empty DagRun and task instance tables."""
    super(PythonOperatorTest, cls).setUpClass()
    db = Session()
    for model in (DagRun, TI):
        db.query(model).delete()
    db.commit()
    db.close()
class TestVariableView(unittest.TestCase):
    """Web tests for the Variable admin view, including undecryptable values."""

    CREATE_ENDPOINT = '/admin/variable/new/?url=/admin/variable/'

    @classmethod
    def setUpClass(cls):
        super(TestVariableView, cls).setUpClass()
        db = Session()
        db.query(models.Variable).delete()
        db.commit()
        db.close()

    def setUp(self):
        super(TestVariableView, self).setUp()
        configuration.load_test_config()
        flask_app = application.create_app(testing=True)
        flask_app.config['WTF_CSRF_METHODS'] = []
        self.app = flask_app.test_client()
        self.session = Session()
        self.variable = dict(key='test_key', val='text_val', is_encrypted=True)

    def tearDown(self):
        self.session.query(models.Variable).delete()
        self.session.commit()
        self.session.close()
        super(TestVariableView, self).tearDown()

    def test_can_handle_error_on_decrypt(self):
        # Create a valid variable through the web form.
        created = self.app.post(
            self.CREATE_ENDPOINT, data=self.variable, follow_redirects=True)
        self.assertEqual(created.status_code, 200)

        # Overwrite its stored value with plain text even though the row is
        # flagged as encrypted, so decrypting it will fail.
        Var = models.Variable
        matching = self.session.query(Var).filter(Var.key == self.variable['key'])
        matching.update({'val': 'failed_value_not_encrypted'},
                        synchronize_session=False)
        self.session.commit()

        # The Variables page must still render and label the row as Invalid.
        listing = self.app.get('/admin/variable', follow_redirects=True)
        self.assertEqual(listing.status_code, 200)
        self.assertEqual(self.session.query(models.Variable).count(), 1)
        page = listing.data.decode('utf-8')
        self.assertIn('<span class="label label-danger">Invalid</span>', page)
def tearDown(self):
    """Run the parent tearDown, then delete all DagRun and task instance rows."""
    super(BranchOperatorTest, self).tearDown()
    session = Session()
    session.query(DagRun).delete()
    session.query(TI).delete()
    session.commit()
    session.close()
def setUpClass(cls):
    """Clear KnownEvent and User tables, create one user and store its id on cls."""
    super(TestKnownEventView, cls).setUpClass()
    db = Session()
    for model in (models.KnownEvent, models.User):
        db.query(model).delete()
    db.commit()
    new_user = models.User(username='******')
    db.add(new_user)
    db.commit()
    cls.user_id = new_user.id
    db.close()
def tearDown(self):
    """Delete DagRun/task instance rows and unset the task-context env vars."""
    super(PythonOperatorTest, self).tearDown()
    session = Session()
    session.query(DagRun).delete()
    session.query(TI).delete()
    session.commit()
    session.close()

    # Remove any variables from TI_CONTEXT_ENV_VARS that a test may have
    # exported; pop with a default is a no-op when the variable is absent.
    for var in TI_CONTEXT_ENV_VARS:
        os.environ.pop(var, None)
class TestBase(unittest.TestCase):
    """Shared fixture for webserver tests: app, test client, DB session and login."""

    def setUp(self):
        conf.load_test_config()
        self.app, self.appbuilder = application.create_app(testing=True)
        self.app.config['WTF_CSRF_ENABLED'] = False
        self.client = self.app.test_client()
        self.session = Session()
        self.login()

    def login(self):
        """Create the test admin user if no user row exists, then POST /login/.

        Returns the response of the login request.
        """
        sm_session = self.appbuilder.sm.get_session()
        self.user = sm_session.query(ab_user).first()
        if not self.user:
            admin_role = self.appbuilder.sm.find_role('Admin')
            self.appbuilder.sm.add_user(
                username='******',
                first_name='test',
                last_name='test',
                email='*****@*****.**',
                role=admin_role,
                password='******')
        credentials = dict(username='******', password='******')
        return self.client.post('/login/', data=credentials,
                                follow_redirects=True)

    def logout(self):
        return self.client.get('/logout/')

    def clear_table(self, model):
        """Delete every row of ``model``, commit, and close the test session."""
        self.session.query(model).delete()
        self.session.commit()
        self.session.close()

    def check_content_in_response(self, text, resp, resp_code=200):
        """Assert ``resp`` has status ``resp_code`` and its body contains ``text``.

        ``text`` may be a single string or a list of strings that must all appear.
        """
        resp_html = resp.data.decode('utf-8')
        self.assertEqual(resp_code, resp.status_code)
        expected = text if isinstance(text, list) else [text]
        for snippet in expected:
            self.assertIn(snippet, resp_html)

    def percent_encode(self, obj):
        """Return ``str(obj)`` quoted with ``quote_plus`` for use in a URL."""
        quote_plus = urllib.quote_plus if PY2 else urllib.parse.quote_plus
        return quote_plus(str(obj))
def test_import_variables(self):
    """Import a JSON file of variables and check each stored string value."""
    import json

    content = ('{"str_key": "str_value", "int_key": 60,'
               '"list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}')
    try:
        # python 3+
        bytes_content = io.BytesIO(bytes(content, encoding='utf-8'))
    except TypeError:
        # python 2.7
        bytes_content = io.BytesIO(bytes(content))
    response = self.app.post(
        self.IMPORT_ENDPOINT,
        data={'file': (bytes_content, 'test.json')},
        follow_redirects=True
    )
    self.assertEqual(response.status_code, 200)

    session = Session()
    # Extract values from Variable
    db_dict = {x.key: x.get_val() for x in session.query(models.Variable).all()}
    session.close()

    for key in ('str_key', 'int_key', 'list_key', 'dict_key'):
        self.assertIn(key, db_dict)
    self.assertEqual('str_value', db_dict['str_key'])
    self.assertEqual('60', db_dict['int_key'])
    self.assertEqual('[1, 2]', db_dict['list_key'])
    # The nested dict is stored as serialized JSON whose key order may vary,
    # so compare the parsed value rather than one of two exact strings.
    self.assertEqual({"k_a": 2, "k_b": 3}, json.loads(db_dict['dict_key']))
def test_without_dag_run(self):
    """Check branching when no DagRun exists for the execution date.

    The callable picks 'branch_1', so 'branch_1' keeps no state and
    'branch_2' is skipped.
    """
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: 'branch_1')
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    session = Session()
    # Load the rows while the session is open rather than iterating a lazy
    # query after close().
    tis = session.query(TI).filter(
        TI.dag_id == self.dag.dag_id,
        TI.execution_date == DEFAULT_DATE
    ).all()
    session.close()

    for ti in tis:
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            # should exist with state None
            self.assertEqual(ti.state, State.NONE)
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.SKIPPED)
        else:
            self.fail('Unexpected task instance: {}'.format(ti.task_id))
def test_branch_list_without_dag_run(self):
    """Check that BranchPythonOperator supports branching to a list of tasks.

    The callable returns ['branch_1', 'branch_2'], so both keep no state
    and the unchosen 'branch_3' is skipped.
    """
    self.branch_op = BranchPythonOperator(task_id='make_choice',
                                          dag=self.dag,
                                          python_callable=lambda: ['branch_1', 'branch_2'])
    self.branch_1.set_upstream(self.branch_op)
    self.branch_2.set_upstream(self.branch_op)
    self.branch_3 = DummyOperator(task_id='branch_3', dag=self.dag)
    self.branch_3.set_upstream(self.branch_op)
    self.dag.clear()

    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    session = Session()
    # Load the rows while the session is open rather than iterating a lazy
    # query after close().
    tis = session.query(TI).filter(
        TI.dag_id == self.dag.dag_id,
        TI.execution_date == DEFAULT_DATE
    ).all()
    session.close()

    expected = {
        "make_choice": State.SUCCESS,
        "branch_1": State.NONE,
        "branch_2": State.NONE,
        "branch_3": State.SKIPPED,
    }

    for ti in tis:
        if ti.task_id not in expected:
            self.fail('Unexpected task instance: {}'.format(ti.task_id))
        self.assertEqual(ti.state, expected[ti.task_id])
def test_delete_dag_button_for_dag_on_scheduler_only(self):
    """Delete-dag URL must be correct for a DAG known only to the DB.

    Test for JIRA AIRFLOW-3233 (PR 4069): the delete-dag URL should be
    generated correctly for DAGs that exist on the scheduler (DB) but not in
    the webserver DagBag. The example DAG's DagModel row is temporarily
    renamed to simulate this. The rename is undone in ``finally`` so a failed
    assertion does not leave the DB modified for later tests.
    """
    test_dag_id = "non_existent_dag"

    session = Session()
    DM = models.DagModel
    session.query(DM).filter(DM.dag_id == 'example_bash_operator').update({'dag_id': test_dag_id})
    session.commit()

    try:
        resp = self.app.get('/', follow_redirects=True)
        html = resp.data.decode('utf-8')
        self.assertIn('/delete?dag_id={}'.format(test_dag_id), html)
        self.assertIn("return confirmDeleteDag('{}')".format(test_dag_id), html)
    finally:
        session.query(DM).filter(DM.dag_id == test_dag_id).update({'dag_id': 'example_bash_operator'})
        session.commit()
        session.close()
def test_single_partition_with_templates():
    """Run the sensors through ``run`` so their templated prefixes render.

    Running the tasks (not just calling the operator) is necessary to
    instantiate the templating functionality, which requires context from
    the scheduler. Partition 1 has a _SUCCESS marker and must succeed;
    partition 2 does not and must time out and end FAILED.
    """
    bucket = "test"
    prefix = "dataset/v1/submission_date=20190101"

    client = boto3.client("s3")
    client.create_bucket(Bucket=bucket)
    client.put_object(Bucket=bucket, Body="", Key=prefix + "/part=1/_SUCCESS")
    client.put_object(Bucket=bucket, Body="", Key=prefix + "/part=2/garbage")

    dag = DAG("test_dag", default_args={"owner": "airflow", "start_date": DEFAULT_DATE})

    sensor_success = S3FSCheckSuccessSensor(
        task_id="test_success_template",
        bucket=bucket,
        prefix="dataset/v1/submission_date={{ ds_nodash }}/part=1",
        num_partitions=1,
        poke_interval=1,
        timeout=2,
        dag=dag,
    )
    sensor_failure = S3FSCheckSuccessSensor(
        task_id="test_failure_template",
        bucket=bucket,
        prefix="dataset/v1/submission_date={{ ds_nodash }}/part=2",
        num_partitions=1,
        poke_interval=1,
        timeout=2,
        dag=dag,
    )

    # execute everything for templating to work
    sensor_success.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    with pytest.raises(AirflowSensorTimeout):
        sensor_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    session = Session()
    # Load the rows while the session is open rather than iterating a lazy
    # query after close().
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id, TaskInstance.execution_date == DEFAULT_DATE
    ).all()
    session.close()

    for ti in tis:
        if ti.task_id == "test_success_template":
            assert ti.state == State.SUCCESS
        elif ti.task_id == "test_failure_template":
            assert ti.state == State.FAILED
        else:
            pytest.fail("unexpected task instance: {}".format(ti.task_id))
    assert len(tis) == 2
def test_without_dag_run(self):
    """Check ShortCircuitOperator when no DagRun exists for the date.

    With a falsy callable both downstream branches are skipped. After
    clearing the DAG and re-running with a truthy callable, they are left
    with no state. The 'upstream' task is never run, so it must have no
    task instance.
    """
    value = False
    dag = DAG('shortcircuit_operator_test_without_dag_run',
              default_args={
                  'owner': 'airflow',
                  'start_date': DEFAULT_DATE
              },
              schedule_interval=INTERVAL)
    # The lambda reads ``value`` when called, so reassigning ``value`` below
    # flips the short-circuit decision for the second run.
    short_op = ShortCircuitOperator(task_id='make_choice',
                                    dag=dag,
                                    python_callable=lambda: value)
    branch_1 = DummyOperator(task_id='branch_1', dag=dag)
    branch_1.set_upstream(short_op)
    branch_2 = DummyOperator(task_id='branch_2', dag=dag)
    branch_2.set_upstream(branch_1)
    upstream = DummyOperator(task_id='upstream', dag=dag)
    upstream.set_downstream(short_op)
    dag.clear()

    session = Session()
    # Deliberately kept lazy: it is iterated twice below, and each iteration
    # re-executes the SQL.
    tis = session.query(TI).filter(
        TI.dag_id == dag.dag_id,
        TI.execution_date == DEFAULT_DATE
    )

    def assert_states(branch_state):
        """Assert make_choice succeeded and both branches are in ``branch_state``."""
        for ti in tis:
            if ti.task_id == 'make_choice':
                self.assertEqual(ti.state, State.SUCCESS)
            elif ti.task_id == 'upstream':
                self.fail('task instance for "upstream" should not exist')
            elif ti.task_id in ('branch_1', 'branch_2'):
                self.assertEqual(ti.state, branch_state)
            else:
                self.fail('Unexpected task instance: {}'.format(ti.task_id))

    try:
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        assert_states(State.SKIPPED)

        value = True
        dag.clear()
        short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        assert_states(State.NONE)
    finally:
        session.close()
def test_charts(self):
    """The chart, chart-data and DAG-details pages render expected text."""
    chart_label = "Airflow task instance by type"
    db = Session()
    chart_id = db.query(models.Chart).filter(
        models.Chart.label == chart_label).first().id
    db.close()

    chart_args = "?chart_id={}&iteration_no=1".format(chart_id)

    page = self.app.get("/admin/airflow/chart" + chart_args)
    assert "Airflow task instance by type" in page.data.decode("utf-8")

    page = self.app.get("/admin/airflow/chart_data" + chart_args)
    assert "example" in page.data.decode("utf-8")

    page = self.app.get("/admin/airflow/dag_details?dag_id=example_branch_operator")
    assert "run_this_first" in page.data.decode("utf-8")
def log(self):
    """Render the log of one task instance as an HTML page.

    Reads ``dag_id``, ``task_id`` and ``execution_date`` from the request
    query string. If the matching TaskInstance's hostname equals this
    host's, the log is read from ``[core] BASE_LOG_FOLDER``. Otherwise it
    is fetched over HTTP from ``[celery] WORKER_LOG_SERVER_PORT`` on that
    host. Read/fetch failures are reported inside the page rather than
    raised.
    """
    BASE_LOG_FOLDER = conf.get('core', 'BASE_LOG_FOLDER')
    dag_id = request.args.get('dag_id')
    task_id = request.args.get('task_id')
    execution_date = request.args.get('execution_date')
    dag = dagbag.dags[dag_id]
    log_relative = "/{dag_id}/{task_id}/{execution_date}".format(
        dag_id=dag_id, task_id=task_id, execution_date=execution_date)
    # log_relative is already fully formatted; formatting the joined path a
    # second time would break on ids that contain braces.
    loc = BASE_LOG_FOLDER + log_relative
    log = ""
    TI = models.TaskInstance
    session = Session()
    dttm = dateutil.parser.parse(execution_date)
    ti = session.query(TI).filter(
        TI.dag_id == dag_id,
        TI.task_id == task_id,
        TI.execution_date == dttm).first()
    if ti:
        host = ti.hostname
        if socket.gethostname() == host:
            try:
                with open(loc) as f:
                    log += f.read()
            except (IOError, OSError):
                log = "Log file isn't where expected: {}\n".format(loc)
        else:
            WORKER_LOG_SERVER_PORT = \
                conf.get('celery', 'WORKER_LOG_SERVER_PORT')
            url = "http://{host}:{port}/log{log_relative}".format(
                host=host, port=WORKER_LOG_SERVER_PORT,
                log_relative=log_relative)
            log += "Log file isn't local.\n"
            log += "Fetching here: {url}\n".format(url=url)
            try:
                try:
                    from urllib2 import urlopen  # Python 2
                except ImportError:
                    from urllib.request import urlopen  # Python 3
                w = urlopen(url)
                try:
                    content = w.read()
                finally:
                    w.close()
                # Python 3 returns bytes; Python 2 returns str.
                if not isinstance(content, str):
                    content = content.decode('utf-8', 'replace')
                log += content
            except Exception:
                # Best effort: report the failure in the page, don't crash.
                log += "Failed to fetch log file from {url}.\n".format(url=url)
        session.commit()
    session.close()

    try:
        from html import escape  # Python 3
    except ImportError:
        from cgi import escape  # Python 2
    # The <pre><code> wrapper built here implies the template emits
    # html_code unescaped, so escape the raw log text to prevent markup or
    # scripts in task output from being rendered (XSS).
    log = "<pre><code>{0}</code></pre>".format(escape(log))
    title = "Logs for {task_id} on {execution_date}".format(
        task_id=task_id, execution_date=execution_date)
    html_code = log
    return self.render(
        'airflow/dag_code.html', html_code=html_code, dag=dag, title=title)
def test_charts(self):
    """The chart and chart-data pages render for the built-in chart."""
    chart_label = "Airflow task instance by type"
    db = Session()
    matching_chart = db.query(models.Chart).filter(
        models.Chart.label == chart_label).first()
    chart_id = matching_chart.id
    db.close()

    query_string = '?chart_id={}&iteration_no=1'.format(chart_id)

    page = self.app.get('/admin/airflow/chart' + query_string)
    assert "Airflow task instance by type" in page.data.decode('utf-8')

    page = self.app.get('/admin/airflow/chart_data' + query_string)
    assert "example" in page.data.decode('utf-8')
def registered(self, driver, frameworkId, masterInfo):
    """Callback run when the scheduler framework registers with Mesos.

    When ``[mesos] CHECKPOINT`` is true and ``[mesos] FAILOVER_TIMEOUT`` is
    set, the framework ID is stored in the ``extra`` field of a Connection
    row keyed by ``FRAMEWORK_CONNID_PREFIX + get_framework_name()``. The row
    is created if it is missing.
    """
    logging.info("AirflowScheduler registered to mesos with framework ID %s", frameworkId.value)

    if configuration.getboolean('mesos', 'CHECKPOINT') and configuration.get('mesos', 'FAILOVER_TIMEOUT'):
        # Import here to work around a circular import error
        from airflow.models import Connection

        # Update the Framework ID in the database.
        session = Session()
        try:
            conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
            # Query through the local session object rather than the
            # ``Session`` class proxy, so all work uses one explicit session.
            connection = session.query(Connection).filter_by(conn_id=conn_id).first()
            if connection is None:
                connection = Connection(conn_id=conn_id, conn_type='mesos_framework-id',
                                        extra=frameworkId.value)
            else:
                connection.extra = frameworkId.value

            session.add(connection)
            session.commit()
        finally:
            # Dispose of the session even if the commit fails.
            Session.remove()
def test_import_variable_fail(self):
    """A Variable.set failure during import must not store the variable.

    The import request itself must still return 200.
    """
    with mock.patch('airflow.models.Variable.set') as set_mock:
        # side_effect needs an exception *instance*. Assigning the bare
        # UnicodeEncodeError class makes the mock instantiate it with no
        # arguments, which raises TypeError instead of the intended error.
        set_mock.side_effect = UnicodeEncodeError(
            'utf-8', u'fail_val', 0, 1, 'mocked encoding failure')
        content = '{"fail_key": "fail_val"}'

        try:
            # python 3+
            bytes_content = io.BytesIO(bytes(content, encoding='utf-8'))
        except TypeError:
            # python 2.7
            bytes_content = io.BytesIO(bytes(content))
        response = self.app.post(
            self.IMPORT_ENDPOINT,
            data={'file': (bytes_content, 'test.json')},
            follow_redirects=True
        )
        self.assertEqual(response.status_code, 200)
        session = Session()
        db_dict = {x.key: x.get_val() for x in session.query(models.Variable).all()}
        session.close()
        self.assertNotIn('fail_key', db_dict)
def test_without_dag_run(self):
    """Check branching when no DagRun exists for the execution date.

    'branch_1' must keep no state and 'branch_2' must be skipped.
    """
    self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    session = Session()
    # Load the rows while the session is open rather than iterating a lazy
    # query after close().
    tis = session.query(TI).filter(
        TI.dag_id == self.dag.dag_id,
        TI.execution_date == DEFAULT_DATE
    ).all()
    session.close()

    for ti in tis:
        if ti.task_id == 'make_choice':
            self.assertEqual(ti.state, State.SUCCESS)
        elif ti.task_id == 'branch_1':
            # should exist with state None
            self.assertEqual(ti.state, State.NONE)
        elif ti.task_id == 'branch_2':
            self.assertEqual(ti.state, State.SKIPPED)
        else:
            # A bare ``raise`` here would fail with "No active exception to
            # reraise" instead of a meaningful test failure.
            self.fail('Unexpected task instance: {}'.format(ti.task_id))
def notify(self, context=None, success=False):
    """Send Slack and email notifications about a DAG run's outcome.

    Fills ``context`` with run status, start/end/elapsed time, the
    (display-adjusted) task instances and the operator classes used. It
    then optionally posts to Slack and sends an email. Per-execution-date
    email history is kept in the ``<dag_id>.state`` Variable, so a repeated
    failure email is skipped unless ``send_multiple_failures`` is set.
    Task-instance edits are display-only: the session is rolled back in
    ``finally``. Any exception is logged as a warning and not re-raised.
    """
    ts = context['ts']
    dag = context['dag']
    did = dag.dag_id
    if success:
        context['dagrun_status'] = 'SUCCESS'
        context['dagrun_class'] = 'success'
    else:
        context['dagrun_status'] = 'FAILED'
        context['dagrun_class'] = 'failed'
    context['elapsed_time'] = 'unknown'
    task_id = 'unknown'
    session = Session()
    try:
        task_id = context['task'].task_id
        logging.info('Context task_id {}'.format(task_id))

        first_started = session.query(TaskInstance)\
            .filter(TaskInstance.dag_id == did)\
            .filter(TaskInstance.execution_date == ts)\
            .filter(TaskInstance.start_date.isnot(None))\
            .order_by(TaskInstance.start_date.asc())\
            .first()
        # No started task yet -> start_time None, reported as 'N/A' below
        # instead of an AttributeError aborting the whole notification.
        start_time = first_started.start_date if first_started else None
        context['start_time'] = start_time
        end_time = datetime.now()
        context['end_time'] = end_time
        context['elapsed_time'] = self.td_format(
            end_time - start_time) if (start_time and end_time) else 'N/A'

        task_instances = session.query(TaskInstance)\
            .filter(TaskInstance.dag_id == did)\
            .filter(TaskInstance.execution_date == ts)\
            .filter(TaskInstance.state != State.REMOVED)\
            .order_by(TaskInstance.end_date.asc())\
            .all()
        tis = []
        for ti in task_instances:
            if ti.task_id == task_id:
                logging.info(
                    'Adjusting details for task_id: {}'.format(task_id))
                # fix status/end/duration for the task which is causing a notification
                ti.end_date = end_time
                ti.state = 'success' if success else 'failed'
                if not ti.duration:
                    # If the reporting task has no duration, make one based on the report time
                    ti.duration = self.td_format(ti.end_date - ti.start_date)
            if not ti.duration:
                # If other tasks are still running, make duration N/A
                ti.duration = 'N/A'
            elif not isinstance(ti.duration, str):
                ti.duration = self.td_format(timedelta(seconds=ti.duration))
            tis.append(ti)
        context['task_instances'] = tis

        context['operators'] = sorted({op.__class__ for op in dag.tasks},
                                      key=lambda x: x.__name__)

        send_slack = self.args[
            'send_slack_message'] if 'send_slack_message' in self.args else True
        if send_slack:
            slack_message = self.message_slack_success if success else self.message_slack_fail
            self.slack_api_params['text'] = context[
                'task'].render_template(None, slack_message, context)
            self.sc.api_call('chat.postMessage', **self.slack_api_params)

        # don't spam email if multiple completions. spamming Slack is OK ;-)
        state_key = context['dag'].dag_id + '.state'
        dag_state = Variable.get(state_key, deserialize_json=True, default_var={})
        # setdefault replaces dict.has_key, which does not exist on Python 3
        # (the resulting AttributeError was swallowed below, so no email was
        # ever sent there).
        history = dag_state.setdefault('history', {})
        date = history.setdefault(ts, {})
        sent_email_key = 'sent_success_email' if success else 'sent_failure_email'
        date.setdefault(sent_email_key, False)

        send_multiple_failures = self.get_value_from_args(
            'send_multiple_failures', False)
        send_success_email = self.get_value_from_args(
            'send_success_emails', True)
        if (not success) and date[sent_email_key] and not send_multiple_failures:
            # nothing to do here
            logging.info(
                'Skipping failure email notification because one was already sent for {0} regarding date {1}'
                .format(did, ts))
        else:
            subject = self.subject_success if success else self.subject_fail
            title = context['task'].render_template(None, subject, context)
            body = context['task'].render_template(
                None, self.message_completion(), context)
            email_list = context['task'].email
            # conditions to send an email are if task failure or
            # if task succeeds and user wants to receive success emails
            if not success or send_success_email:
                if success:
                    email_list = self.get_value_from_args(
                        'success_email', email_list)
                send_email(email_list, title, body)
                date[sent_email_key] = True
        Variable.set(state_key, dag_state, serialize_json=True)
    except Exception as e:
        logging.warning(
            'Problem reading task state when notifying result of task: {0}'
            '\nException reason: {1}'.format(task_id, e))
    finally:
        # Task-instance changes above are for display only; never persist them.
        session.rollback()
        session.close()
class TestPoolApiExperimental(unittest.TestCase):
    """Tests for the experimental REST API pool endpoints."""

    @classmethod
    def setUpClass(cls):
        super(TestPoolApiExperimental, cls).setUpClass()
        db = Session()
        db.query(Pool).delete()
        db.commit()
        db.close()

    def setUp(self):
        """Create two pools, 'experimental_1' (0 slots) and 'experimental_2' (1 slot)."""
        super(TestPoolApiExperimental, self).setUp()
        configuration.load_test_config()
        flask_app, _ = application.create_app(testing=True)
        self.app = flask_app.test_client()
        self.session = Session()
        self.pools = []
        for idx in range(2):
            pool_name = 'experimental_%s' % (idx + 1)
            new_pool = Pool(pool=pool_name, slots=idx, description=pool_name)
            self.session.add(new_pool)
            self.pools.append(new_pool)
        self.session.commit()
        self.pool = self.pools[0]

    def tearDown(self):
        self.session.query(Pool).delete()
        self.session.commit()
        self.session.close()
        super(TestPoolApiExperimental, self).tearDown()

    @staticmethod
    def _load_json(response):
        """Decode a response body as UTF-8 JSON."""
        return json.loads(response.data.decode('utf-8'))

    def _get_pool_count(self):
        listing = self.app.get('/api/experimental/pools')
        self.assertEqual(listing.status_code, 200)
        return len(self._load_json(listing))

    def test_get_pool(self):
        reply = self.app.get('/api/experimental/pools/{}'.format(self.pool.pool))
        self.assertEqual(reply.status_code, 200)
        self.assertEqual(self._load_json(reply), self.pool.to_json())

    def test_get_pool_non_existing(self):
        reply = self.app.get('/api/experimental/pools/foo')
        self.assertEqual(reply.status_code, 404)
        self.assertEqual(self._load_json(reply)['error'],
                         "Pool 'foo' doesn't exist")

    def test_get_pools(self):
        reply = self.app.get('/api/experimental/pools')
        self.assertEqual(reply.status_code, 200)
        returned = self._load_json(reply)
        self.assertEqual(len(returned), 2)
        by_name = sorted(returned, key=lambda p: p['pool'])
        for idx, pool in enumerate(by_name):
            self.assertDictEqual(pool, self.pools[idx].to_json())

    def test_create_pool(self):
        payload = json.dumps({'name': 'foo', 'slots': 1, 'description': ''})
        reply = self.app.post('/api/experimental/pools', data=payload,
                              content_type='application/json')
        self.assertEqual(reply.status_code, 200)
        created = self._load_json(reply)
        self.assertEqual(created['pool'], 'foo')
        self.assertEqual(created['slots'], 1)
        self.assertEqual(created['description'], '')
        self.assertEqual(self._get_pool_count(), 3)

    def test_create_pool_with_bad_name(self):
        for blank_name in ('', ' '):
            payload = json.dumps({'name': blank_name, 'slots': 1, 'description': ''})
            reply = self.app.post('/api/experimental/pools', data=payload,
                                  content_type='application/json')
            self.assertEqual(reply.status_code, 400)
            self.assertEqual(self._load_json(reply)['error'],
                             "Pool name shouldn't be empty")
        self.assertEqual(self._get_pool_count(), 2)

    def test_delete_pool(self):
        reply = self.app.delete('/api/experimental/pools/{}'.format(self.pool.pool))
        self.assertEqual(reply.status_code, 200)
        self.assertEqual(self._load_json(reply), self.pool.to_json())
        self.assertEqual(self._get_pool_count(), 1)

    def test_delete_pool_non_existing(self):
        reply = self.app.delete('/api/experimental/pools/foo')
        self.assertEqual(reply.status_code, 404)
        self.assertEqual(self._load_json(reply)['error'],
                         "Pool 'foo' doesn't exist")
class TestLogView(unittest.TestCase):
    """Web tests for the task log views, run against a temporary logging config."""

    DAG_ID = 'dag_for_testing_log_view'
    TASK_ID = 'task_for_testing_log_view'
    DEFAULT_DATE = datetime(2017, 9, 1)
    ENDPOINT = '/admin/airflow/log?dag_id={dag_id}&task_id={task_id}&execution_date={execution_date}'.format(
        dag_id=DAG_ID,
        task_id=TASK_ID,
        execution_date=DEFAULT_DATE,
    )

    @classmethod
    def _test_ti_query(cls, session):
        """Query for exactly this class's TaskInstance (dag, task and date).

        The conditions are separate ``filter`` arguments so SQLAlchemy ANDs
        them in SQL. Chaining them with Python ``and`` would keep only the
        first clause and match every TaskInstance of the DAG.
        """
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == cls.DAG_ID,
            TaskInstance.task_id == cls.TASK_ID,
            TaskInstance.execution_date == cls.DEFAULT_DATE)

    @classmethod
    def setUpClass(cls):
        super(TestLogView, cls).setUpClass()
        session = Session()
        cls._test_ti_query(session).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(TestLogView, self).setUp()
        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.write(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()

    def tearDown(self):
        """Restore logging/config state and delete this test's TaskInstance."""
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        self._test_ti_query(self.session).delete()
        self.session.commit()
        self.session.close()

        sys.path.remove(self.settings_folder)
        shutil.rmtree(self.settings_folder)
        conf.set('core', 'logging_config_class', '')

        super(TestLogView, self).tearDown()

    def test_get_file_task_log(self):
        response = self.app.get(
            TestLogView.ENDPOINT,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn('Log by attempts', response.data.decode('utf-8'))

    def test_get_logs_with_metadata_as_download_file(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}&format=file"
        try_number = 1
        url = url_template.format(self.DAG_ID,
                                  self.TASK_ID,
                                  quote_plus(self.DEFAULT_DATE.isoformat()),
                                  try_number,
                                  json.dumps({}))
        response = self.app.get(url)
        expected_filename = '{}/{}/{}/{}.log'.format(self.DAG_ID,
                                                     self.TASK_ID,
                                                     self.DEFAULT_DATE.isoformat(),
                                                     try_number)

        content_disposition = response.headers.get('Content-Disposition')
        self.assertTrue(content_disposition.startswith('attachment'))
        self.assertTrue(content_disposition.endswith(expected_filename))
        self.assertEqual(200, response.status_code)
        self.assertIn('Log for testing.', response.data.decode('utf-8'))

    def test_get_logs_with_metadata(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}"
        response = \
            self.app.get(url_template.format(self.DAG_ID,
                                             self.TASK_ID,
                                             quote_plus(self.DEFAULT_DATE.isoformat()),
                                             1,
                                             json.dumps({})))

        body = response.data.decode('utf-8')
        self.assertIn('"message":', body)
        self.assertIn('"metadata":', body)
        self.assertIn('Log for testing.', body)
        self.assertEqual(200, response.status_code)

    def test_get_logs_with_null_metadata(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata=null"
        response = \
            self.app.get(url_template.format(self.DAG_ID,
                                             self.TASK_ID,
                                             quote_plus(self.DEFAULT_DATE.isoformat()),
                                             1))

        body = response.data.decode('utf-8')
        self.assertIn('"message":', body)
        self.assertIn('"metadata":', body)
        self.assertIn('Log for testing.', body)
        self.assertEqual(200, response.status_code)
def tearDown(self):
    """Remove every DagRun row before handing off to the parent tearDown."""
    db = Session()
    all_runs = db.query(DagRun)
    all_runs.delete()
    db.commit()
    db.close()
    super().tearDown()
class ViewWithDateTimeAndNumRunsAndDagRunsFormTester:
    """Shared checks for views rendering the base-date / num-runs / dag-run form.

    This is not a ``unittest.TestCase``. Every assertion is delegated to
    ``self.test``, which must provide ``assertIn``, ``assertNotIn`` and
    ``assertEqual``. ``endpoint`` is the view URL; the test methods append
    ``&key=value`` query fragments to it.
    """

    DAG_ID = 'dag_for_testing_dt_nr_dr_form'
    DEFAULT_DATE = datetime(2017, 9, 1)
    # (run_id, execution_date) pairs, newest first; self.runs keeps this order.
    RUNS_DATA = [
        ('dag_run_for_testing_dt_nr_dr_form_4', datetime(2018, 4, 4)),
        ('dag_run_for_testing_dt_nr_dr_form_3', datetime(2018, 3, 3)),
        ('dag_run_for_testing_dt_nr_dr_form_2', datetime(2018, 2, 2)),
        ('dag_run_for_testing_dt_nr_dr_form_1', datetime(2018, 1, 1)),
    ]

    def __init__(self, test, endpoint):
        self.test = test
        self.endpoint = endpoint

    def setUp(self):
        """Build a test client and create one SUCCESS dag run per RUNS_DATA entry."""
        configuration.load_test_config()
        flask_app = application.create_app(testing=True)
        flask_app.config['WTF_CSRF_METHODS'] = []
        self.app = flask_app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        from airflow.utils.state import State
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        self.runs = [
            dag.create_dagrun(
                run_id=run_id,
                execution_date=execution_date,
                state=State.SUCCESS,
                external_trigger=True,
            )
            for run_id, execution_date in self.RUNS_DATA
        ]

    def tearDown(self):
        """Delete the dag runs created for DAG_ID and close the session."""
        runs_for_dag = self.session.query(DagRun).filter(DagRun.dag_id == self.DAG_ID)
        runs_for_dag.delete()
        self.session.commit()
        self.session.close()

    def _get_page(self, query=''):
        """GET ``self.endpoint + query``, assert HTTP 200, return the decoded body."""
        response = self.app.get(self.endpoint + query)
        self.test.assertEqual(response.status_code, 200)
        return response.data.decode('utf-8')

    def assertBaseDateAndNumRuns(self, base_date, num_runs, data):
        # NOTE(review): both checks are assertNotIn, so they pass whenever this
        # exact markup is absent. They do not prove that the form displays
        # base_date / num_runs. Kept as-is to preserve behavior.
        base_date_input = 'name="base_date" value="{}"'.format(base_date)
        num_runs_option = '<option selected="" value="{}">{}</option>'.format(
            num_runs, num_runs)
        self.test.assertNotIn(base_date_input, data)
        self.test.assertNotIn(num_runs_option, data)

    def assertRunIsNotInDropdown(self, run, data):
        for fragment in (run.execution_date.isoformat(), run.run_id):
            self.test.assertNotIn(fragment, data)

    def assertRunIsInDropdownNotSelected(self, run, data):
        option = '<option value="{}">{}</option>'.format(
            run.execution_date.isoformat(), run.run_id)
        self.test.assertIn(option, data)

    def assertRunIsSelected(self, run, data):
        option = '<option selected value="{}">{}</option>'.format(
            run.execution_date.isoformat(), run.run_id)
        self.test.assertIn(option, data)

    def test_with_default_parameters(self):
        """
        No URL parameters: every dag run is listed and the newest one is
        selected. The base date defaults to the current date (not asserted).
        """
        data = self._get_page()
        self.test.assertIn('Base date:', data)
        self.test.assertIn('Number of runs:', data)
        self.assertRunIsSelected(self.runs[0], data)
        for idx in (1, 2, 3):
            self.assertRunIsInDropdownNotSelected(self.runs[idx], data)

    def test_with_execution_date_parameter_only(self):
        """
        Only execution_date given (e.g. following a link from the dag runs view).

        Only runs not newer than execution_date are listed, that run is
        selected, and the base date is set to the execution date.
        """
        chosen = self.runs[1]
        data = self._get_page('&execution_date={}'.format(
            chosen.execution_date.isoformat()))
        default_num_runs = configuration.getint(
            'webserver', 'default_dag_run_display_number')
        self.assertBaseDateAndNumRuns(chosen.execution_date, default_num_runs, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(chosen, data)
        for idx in (2, 3):
            self.assertRunIsInDropdownNotSelected(self.runs[idx], data)

    def test_with_base_date_and_num_runs_parmeters_only(self):
        """
        base_date and num_runs given.

        At most num_runs runs not newer than base_date are listed, the newest
        of them is selected, and the form echoes the submitted values.
        """
        base_run = self.runs[1]
        data = self._get_page('&base_date={}&num_runs=2'.format(
            base_run.execution_date.isoformat()))
        self.assertBaseDateAndNumRuns(base_run.execution_date, 2, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(base_run, data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsNotInDropdown(self.runs[3], data)

    def test_with_base_date_and_num_runs_and_execution_date_outside(self):
        """
        base_date, num_runs and an execution_date outside the new range.

        Simulates changing base date / num runs and pressing "Go". The newest
        run inside the range is selected instead of the out-of-range one.
        """
        base_run, picked_run = self.runs[1], self.runs[0]
        data = self._get_page('&base_date={}&num_runs=42&execution_date={}'.format(
            base_run.execution_date.isoformat(),
            picked_run.execution_date.isoformat()))
        self.assertBaseDateAndNumRuns(base_run.execution_date, 42, data)
        self.assertRunIsNotInDropdown(picked_run, data)
        self.assertRunIsSelected(base_run, data)
        for idx in (2, 3):
            self.assertRunIsInDropdownNotSelected(self.runs[idx], data)

    def test_with_base_date_and_num_runs_and_execution_date_within(self):
        """
        base_date, num_runs and an execution_date inside the new range.

        Simulates changing base date / num runs and pressing "Go". The run
        matching execution_date stays selected.
        """
        base_run, picked_run = self.runs[2], self.runs[3]
        data = self._get_page('&base_date={}&num_runs=5&execution_date={}'.format(
            base_run.execution_date.isoformat(),
            picked_run.execution_date.isoformat()))
        self.assertBaseDateAndNumRuns(base_run.execution_date, 5, data)
        for idx in (0, 1):
            self.assertRunIsNotInDropdown(self.runs[idx], data)
        self.assertRunIsInDropdownNotSelected(base_run, data)
        self.assertRunIsSelected(picked_run, data)
class TestMarkTasks(unittest.TestCase):
    """Tests for ``set_state`` against the example DAGs.

    setUp creates RUNNING dag runs for ``test_example_bash_operator`` (dag1)
    and ``example_subdag_operator`` (dag2). Each test snapshots task-instance
    states, calls ``set_state`` and checks that only the expected task
    instances changed.
    """

    def setUp(self):
        self.dagbag = models.DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['test_example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        # dag1 gets runs at two execution dates: two days ago and one day ago.
        self.execution_dates = [days_ago(2), days_ago(1)]

        drs = _create_dagruns(self.dag1, self.execution_dates,
                              state=State.RUNNING,
                              run_id_template="scheduled__{}")
        for dr in drs:
            dr.dag = self.dag1
            # verify_integrity() creates task instances missing from the run
            # (the same mechanism set_state relies on).
            dr.verify_integrity()

        # dag2 gets a single run at its default_args start_date.
        drs = _create_dagruns(self.dag2,
                              [self.dag2.default_args['start_date']],
                              state=State.RUNNING,
                              run_id_template="scheduled__{}")

        for dr in drs:
            dr.dag = self.dag2
            dr.verify_integrity()

        self.session = Session()

    def snapshot_state(self, dag, execution_dates):
        """Return dag's task instances at execution_dates, detached from the session.

        NOTE(review): expunge_all() presumably keeps later queries in
        verify_state from returning (and refreshing) these same identity-map
        objects, so the snapshot keeps the pre-change states.
        """
        TI = models.TaskInstance
        tis = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)).all()

        self.session.expunge_all()

        return tis

    def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
        """Assert the states of dag's task instances at execution_dates.

        Instances whose task_id is in task_ids (and whose date is in
        execution_dates) must be in ``state``. Every other instance that has a
        match in ``old_tis`` (same task_id and execution_date) must still have
        the state recorded there. Also asserts that at least one instance exists.
        """
        TI = models.TaskInstance

        tis = self.session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(execution_dates)).all()

        self.assertTrue(len(tis) > 0)

        for ti in tis:
            if ti.task_id in task_ids and ti.execution_date in execution_dates:
                self.assertEqual(ti.state, state)
            else:
                for old_ti in old_tis:
                    if (old_ti.task_id == ti.task_id and
                            old_ti.execution_date == ti.execution_date):
                        self.assertEqual(ti.state, old_ti.state)

    def test_mark_tasks_now(self):
        # set one task to success but do not commit: one instance is reported
        # as altered while its stored state stays None
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=False, future=False,
                            past=False, state=State.SUCCESS, commit=False)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          None, snapshot)

        # set one and only one task to success
        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=False, future=False,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

        # set no tasks: the instance is already SUCCESS, so nothing is altered
        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=False, future=False,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 0)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

        # set task to other than success
        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=False, future=False,
                            past=False, state=State.FAILED, commit=True)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.FAILED, snapshot)

        # don't alter other tasks: marking runme_0 leaves runme_1 as it was
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_0")
        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=False, future=False,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 1)
        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_downstream(self):
        # test downstream: runme_1 plus all of its downstream relatives
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        relatives = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in relatives]
        task_ids.append(task.task_id)

        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=True, future=False,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 3)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_upstream(self):
        # test upstream: run_after_loop plus all of its upstream relatives
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("run_after_loop")
        relatives = task.get_flat_relatives(upstream=True)
        task_ids = [t.task_id for t in relatives]
        task_ids.append(task.task_id)

        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=True, downstream=False, future=False,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 4)
        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, snapshot)

    def test_mark_tasks_future(self):
        # mark runme_1 from the first execution date onwards (future=True),
        # which covers both dag runs
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=False, future=True,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_past(self):
        # mark runme_1 from the last execution date backwards (past=True),
        # which covers both dag runs
        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
        task = self.dag1.get_task("runme_1")
        altered = set_state(task=task, execution_date=self.execution_dates[1],
                            upstream=False, downstream=False, future=False,
                            past=True, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 2)
        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
                          State.SUCCESS, snapshot)

    def test_mark_tasks_subdag(self):
        # mark the SubDagOperator "section-1" and its downstream tasks;
        # set_state also marks the task instances inside the subdag
        task = self.dag2.get_task("section-1")
        relatives = task.get_flat_relatives(upstream=False)
        task_ids = [t.task_id for t in relatives]
        task_ids.append(task.task_id)

        altered = set_state(task=task, execution_date=self.execution_dates[0],
                            upstream=False, downstream=True, future=False,
                            past=False, state=State.SUCCESS, commit=True)
        self.assertEqual(len(altered), 14)

        # cannot use snapshot here as that will require drilling down the
        # sub dag tree essentially recreating the same code as in the
        # tested logic.
        self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
                          State.SUCCESS, [])

    def tearDown(self):
        self.dag1.clear()
        self.dag2.clear()

        # just to make sure we are fully cleaned up: remove every dag run and
        # task instance, including those created for subdags
        self.session.query(models.DagRun).delete()
        self.session.query(models.TaskInstance).delete()
        self.session.commit()
        self.session.close()
def reset(dag_id):
    """Delete every TaskInstance row that belongs to ``dag_id``.

    :param dag_id: id of the DAG whose task instances are removed
    """
    session = Session()
    try:
        session.query(TaskInstance).filter_by(dag_id=dag_id).delete()
        session.commit()
    finally:
        # Close even when the delete or commit raises, so the session is not leaked.
        session.close()
class TestLogView(unittest.TestCase):
    """Tests for the task-log view when the task instance has no log file.

    setUp writes a temporary ``airflow_local_settings`` module. Its logging
    config points the task handler at the ``test_logs`` folder next to this
    file. setUp also stores one TaskInstance (try_number=1) for
    DAG_ID / TASK_ID / DEFAULT_DATE.
    """
    DAG_ID = 'dag_for_testing_log_view'
    TASK_ID = 'task_for_testing_log_view'
    DEFAULT_DATE = datetime(2017, 9, 1)
    ENDPOINT = '/admin/airflow/log?dag_id={dag_id}&task_id={task_id}&execution_date={execution_date}'.format(
        dag_id=DAG_ID,
        task_id=TASK_ID,
        execution_date=DEFAULT_DATE,
    )

    @classmethod
    def _delete_test_task_instance(cls, session):
        """Delete the TaskInstance for DAG_ID / TASK_ID / DEFAULT_DATE, if any.

        The criteria are passed to ``filter`` as separate arguments so that
        SQLAlchemy ANDs them. Joining them with Python's ``and`` collapses the
        expression to a single clause, so the delete matches too broadly.
        """
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == cls.DAG_ID,
            TaskInstance.task_id == cls.TASK_ID,
            TaskInstance.execution_date == cls.DEFAULT_DATE).delete()

    @classmethod
    def setUpClass(cls):
        super(TestLogView, cls).setUpClass()
        session = Session()
        cls._delete_test_task_instance(session)
        session.commit()
        session.close()

    def setUp(self):
        super(TestLogView, self).setUp()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to an importable module
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            # One write() call; writelines() on a str iterates it char by char.
            handle.write(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()

    def tearDown(self):
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        # (An unused DagBag(settings.DAGS_FOLDER) was built here; it only
        # re-parsed the DAGs folder and was discarded, so it is gone.)
        self._delete_test_task_instance(self.session)
        self.session.commit()
        self.session.close()
        sys.path.remove(self.settings_folder)
        shutil.rmtree(self.settings_folder)
        conf.set('core', 'logging_config_class', '')
        super(TestLogView, self).tearDown()

    def test_get_file_task_log(self):
        response = self.app.get(
            TestLogView.ENDPOINT,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn('Log file does not exist',
                      response.data.decode('utf-8'))
class ViewWithDateTimeAndNumRunsAndDagRunsFormTester:
    """Shared checks for views rendering the base-date / num-runs / dag-run form.

    This is not a ``unittest.TestCase``. Every assertion is delegated to
    ``self.test``, which must provide ``assertIn``, ``assertNotIn`` and
    ``assertEqual``. ``endpoint`` is the view URL; the test methods append
    ``&key=value`` query fragments to it.
    """

    DAG_ID = 'dag_for_testing_dt_nr_dr_form'
    DEFAULT_DATE = datetime(2017, 9, 1)
    # (run_id, execution_date) pairs, newest first; self.runs keeps this order.
    RUNS_DATA = [
        ('dag_run_for_testing_dt_nr_dr_form_4', datetime(2018, 4, 4)),
        ('dag_run_for_testing_dt_nr_dr_form_3', datetime(2018, 3, 3)),
        ('dag_run_for_testing_dt_nr_dr_form_2', datetime(2018, 2, 2)),
        ('dag_run_for_testing_dt_nr_dr_form_1', datetime(2018, 1, 1)),
    ]

    def __init__(self, test, endpoint):
        self.test = test
        self.endpoint = endpoint

    def setUp(self):
        """Build a test client and create one SUCCESS dag run per RUNS_DATA entry."""
        configuration.load_test_config()
        flask_app = application.create_app(testing=True)
        flask_app.config['WTF_CSRF_METHODS'] = []
        self.app = flask_app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        from airflow.utils.state import State
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        self.runs = [
            dag.create_dagrun(
                run_id=run_id,
                execution_date=execution_date,
                state=State.SUCCESS,
                external_trigger=True,
            )
            for run_id, execution_date in self.RUNS_DATA
        ]

    def tearDown(self):
        """Delete the dag runs created for DAG_ID and close the session."""
        runs_for_dag = self.session.query(DagRun).filter(DagRun.dag_id == self.DAG_ID)
        runs_for_dag.delete()
        self.session.commit()
        self.session.close()

    def _get_page(self, query=''):
        """GET ``self.endpoint + query``, assert HTTP 200, return the decoded body."""
        response = self.app.get(self.endpoint + query)
        self.test.assertEqual(response.status_code, 200)
        return response.data.decode('utf-8')

    def assertBaseDateAndNumRuns(self, base_date, num_runs, data):
        # NOTE(review): both checks are assertNotIn, so they pass whenever this
        # exact markup is absent. They do not prove that the form displays
        # base_date / num_runs. Kept as-is to preserve behavior.
        base_date_input = 'name="base_date" value="{}"'.format(base_date)
        num_runs_option = '<option selected="" value="{}">{}</option>'.format(
            num_runs, num_runs)
        self.test.assertNotIn(base_date_input, data)
        self.test.assertNotIn(num_runs_option, data)

    def assertRunIsNotInDropdown(self, run, data):
        for fragment in (run.execution_date.isoformat(), run.run_id):
            self.test.assertNotIn(fragment, data)

    def assertRunIsInDropdownNotSelected(self, run, data):
        option = '<option value="{}">{}</option>'.format(
            run.execution_date.isoformat(), run.run_id)
        self.test.assertIn(option, data)

    def assertRunIsSelected(self, run, data):
        option = '<option selected value="{}">{}</option>'.format(
            run.execution_date.isoformat(), run.run_id)
        self.test.assertIn(option, data)

    def test_with_default_parameters(self):
        """
        No URL parameters: every dag run is listed and the newest one is
        selected. The base date defaults to the current date (not asserted).
        """
        data = self._get_page()
        self.test.assertIn('Base date:', data)
        self.test.assertIn('Number of runs:', data)
        self.assertRunIsSelected(self.runs[0], data)
        for idx in (1, 2, 3):
            self.assertRunIsInDropdownNotSelected(self.runs[idx], data)

    def test_with_execution_date_parameter_only(self):
        """
        Only execution_date given (e.g. following a link from the dag runs view).

        Only runs not newer than execution_date are listed, that run is
        selected, and the base date is set to the execution date.
        """
        chosen = self.runs[1]
        data = self._get_page('&execution_date={}'.format(
            chosen.execution_date.isoformat()))
        default_num_runs = configuration.getint(
            'webserver', 'default_dag_run_display_number')
        self.assertBaseDateAndNumRuns(chosen.execution_date, default_num_runs, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(chosen, data)
        for idx in (2, 3):
            self.assertRunIsInDropdownNotSelected(self.runs[idx], data)

    def test_with_base_date_and_num_runs_parmeters_only(self):
        """
        base_date and num_runs given.

        At most num_runs runs not newer than base_date are listed, the newest
        of them is selected, and the form echoes the submitted values.
        """
        base_run = self.runs[1]
        data = self._get_page('&base_date={}&num_runs=2'.format(
            base_run.execution_date.isoformat()))
        self.assertBaseDateAndNumRuns(base_run.execution_date, 2, data)
        self.assertRunIsNotInDropdown(self.runs[0], data)
        self.assertRunIsSelected(base_run, data)
        self.assertRunIsInDropdownNotSelected(self.runs[2], data)
        self.assertRunIsNotInDropdown(self.runs[3], data)

    def test_with_base_date_and_num_runs_and_execution_date_outside(self):
        """
        base_date, num_runs and an execution_date outside the new range.

        Simulates changing base date / num runs and pressing "Go". The newest
        run inside the range is selected instead of the out-of-range one.
        """
        base_run, picked_run = self.runs[1], self.runs[0]
        data = self._get_page('&base_date={}&num_runs=42&execution_date={}'.format(
            base_run.execution_date.isoformat(),
            picked_run.execution_date.isoformat()))
        self.assertBaseDateAndNumRuns(base_run.execution_date, 42, data)
        self.assertRunIsNotInDropdown(picked_run, data)
        self.assertRunIsSelected(base_run, data)
        for idx in (2, 3):
            self.assertRunIsInDropdownNotSelected(self.runs[idx], data)

    def test_with_base_date_and_num_runs_and_execution_date_within(self):
        """
        base_date, num_runs and an execution_date inside the new range.

        Simulates changing base date / num runs and pressing "Go". The run
        matching execution_date stays selected.
        """
        base_run, picked_run = self.runs[2], self.runs[3]
        data = self._get_page('&base_date={}&num_runs=5&execution_date={}'.format(
            base_run.execution_date.isoformat(),
            picked_run.execution_date.isoformat()))
        self.assertBaseDateAndNumRuns(base_run.execution_date, 5, data)
        for idx in (0, 1):
            self.assertRunIsNotInDropdown(self.runs[idx], data)
        self.assertRunIsInDropdownNotSelected(base_run, data)
        self.assertRunIsSelected(picked_run, data)
class TestLogView(unittest.TestCase):
    """Tests for the task-log views backed by the file task handler.

    setUp writes a temporary ``airflow_local_settings`` module. Its logging
    config points the task handler at the ``test_logs`` folder next to this
    file. setUp also stores one TaskInstance (try_number=1) for
    DAG_ID / TASK_ID / DEFAULT_DATE.
    """
    DAG_ID = 'dag_for_testing_log_view'
    TASK_ID = 'task_for_testing_log_view'
    DEFAULT_DATE = datetime(2017, 9, 1)
    ENDPOINT = '/admin/airflow/log?dag_id={dag_id}&task_id={task_id}&execution_date={execution_date}'.format(
        dag_id=DAG_ID,
        task_id=TASK_ID,
        execution_date=DEFAULT_DATE,
    )

    @classmethod
    def _delete_test_task_instance(cls, session):
        """Delete the TaskInstance for DAG_ID / TASK_ID / DEFAULT_DATE, if any.

        The criteria are passed to ``filter`` as separate arguments so that
        SQLAlchemy ANDs them. Joining them with Python's ``and`` collapses the
        expression to a single clause, so the delete matches too broadly.
        """
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == cls.DAG_ID,
            TaskInstance.task_id == cls.TASK_ID,
            TaskInstance.execution_date == cls.DEFAULT_DATE).delete()

    @classmethod
    def setUpClass(cls):
        super(TestLogView, cls).setUpClass()
        session = Session()
        cls._delete_test_task_instance(session)
        session.commit()
        session.close()

    def setUp(self):
        super(TestLogView, self).setUp()
        # Make sure that the configure_logging is not cached
        self.old_modules = dict(sys.modules)

        # Create a custom logging configuration
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to an importable module
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            # One write() call; writelines() on a str iterates it char by char.
            handle.write(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()

    def tearDown(self):
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        self._delete_test_task_instance(self.session)
        self.session.commit()
        self.session.close()

        # Remove any new modules imported during the test run. This lets us
        # import the same source files for more than one test.
        new_modules = [name for name in sys.modules if name not in self.old_modules]
        for module_name in new_modules:
            del sys.modules[module_name]

        sys.path.remove(self.settings_folder)
        shutil.rmtree(self.settings_folder)
        conf.set('core', 'logging_config_class', '')
        super(TestLogView, self).tearDown()

    def test_get_file_task_log(self):
        response = self.app.get(
            TestLogView.ENDPOINT,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn('Log by attempts',
                      response.data.decode('utf-8'))

    def test_get_logs_with_metadata_as_download_file(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}&format=file"
        try_number = 1
        url = url_template.format(self.DAG_ID,
                                  self.TASK_ID,
                                  quote_plus(self.DEFAULT_DATE.isoformat()),
                                  try_number,
                                  json.dumps({}))
        response = self.app.get(url)
        expected_filename = '{}/{}/{}/{}.log'.format(self.DAG_ID,
                                                     self.TASK_ID,
                                                     self.DEFAULT_DATE.isoformat(),
                                                     try_number)

        self.assertEqual(200, response.status_code)
        content_disposition = response.headers.get('Content-Disposition')
        # A missing header should fail as an assertion, not as AttributeError.
        self.assertIsNotNone(content_disposition)
        self.assertTrue(content_disposition.startswith('attachment'))
        self.assertIn(expected_filename, content_disposition)
        self.assertIn('Log for testing.', response.data.decode('utf-8'))

    def test_get_logs_with_metadata_as_download_large_file(self):
        with mock.patch(
            "airflow.utils.log.file_task_handler.FileTaskHandler.read"
        ) as read_mock:
            first_return = (['1st line'], [{}])
            second_return = (['2nd line'], [{'end_of_log': False}])
            third_return = (['3rd line'], [{'end_of_log': True}])
            fourth_return = (['should never be read'], [{'end_of_log': True}])
            read_mock.side_effect = [
                first_return,
                second_return,
                third_return,
                fourth_return
            ]
            url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                           "task_id={}&execution_date={}&" \
                           "try_number={}&metadata={}&format=file"
            try_number = 1
            url = url_template.format(self.DAG_ID,
                                      self.TASK_ID,
                                      quote_plus(self.DEFAULT_DATE.isoformat()),
                                      try_number,
                                      json.dumps({}))
            response = self.app.get(url)
            body = response.data.decode('utf-8')

            # Reading stops at the first chunk whose metadata has end_of_log=True.
            self.assertIn('1st line', body)
            self.assertIn('2nd line', body)
            self.assertIn('3rd line', body)
            self.assertNotIn('should never be read', body)

    def test_get_logs_with_metadata(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata={}"
        response = \
            self.app.get(url_template.format(self.DAG_ID,
                                             self.TASK_ID,
                                             quote_plus(self.DEFAULT_DATE.isoformat()),
                                             1,
                                             json.dumps({})))
        body = response.data.decode('utf-8')

        self.assertIn('"message":', body)
        self.assertIn('"metadata":', body)
        self.assertIn('Log for testing.', body)
        self.assertEqual(200, response.status_code)

    def test_get_logs_with_null_metadata(self):
        url_template = "/admin/airflow/get_logs_with_metadata?dag_id={}&" \
                       "task_id={}&execution_date={}&" \
                       "try_number={}&metadata=null"
        response = \
            self.app.get(url_template.format(self.DAG_ID,
                                             self.TASK_ID,
                                             quote_plus(self.DEFAULT_DATE.isoformat()),
                                             1))
        body = response.data.decode('utf-8')

        self.assertIn('"message":', body)
        self.assertIn('"metadata":', body)
        self.assertIn('Log for testing.', body)
        self.assertEqual(200, response.status_code)
def test_dag_retry_limit_causes_premature_failure():
    """Verify that setting a retry limit on the DAG breaks the sensor functionality.

    We want to make sure that the sensors are allowed to retry until they
    time out. The sensor inheriting ``retries=1`` from default_args should end
    UP_FOR_RETRY. The one with ``retries=0`` should end FAILED.
    """
    bucket = "test"

    client = boto3.client("s3")
    client.create_bucket(Bucket=bucket)

    dag = DAG(
        "test_dag_retries",
        default_args={
            "owner": "airflow",
            "start_date": DEFAULT_DATE,
            "retries": 1,
            "retry_delay": timedelta(seconds=1),
        },
    )

    sensor_retry = S3FSCheckSuccessSensor(
        task_id="test_retry_template",
        bucket=bucket,
        prefix="dataset/v1/submission_date={{ ds_nodash }}",
        num_partitions=1,
        poke_interval=1,
        timeout=2,
        dag=dag,
    )

    sensor_failure = S3FSCheckSuccessSensor(
        task_id="test_failure_template",
        bucket=bucket,
        prefix="dataset/v1/submission_date={{ ds_nodash }}",
        num_partitions=1,
        poke_interval=1,
        timeout=2,
        retries=0,  # disable retries for the correct behavior
        dag=dag,
    )

    # execute everything for templating to work
    with pytest.raises(AirflowSensorTimeout):
        sensor_retry.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
    with pytest.raises(AirflowSensorTimeout):
        sensor_failure.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # Materialize the rows before closing the session. Iterating a lazy query
    # after close() silently opens a new connection that is never released.
    session = Session()
    try:
        tis = session.query(TaskInstance).filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date == DEFAULT_DATE).all()
    finally:
        session.close()

    expected_states = {
        "test_retry_template": State.UP_FOR_RETRY,
        "test_failure_template": State.FAILED,
    }
    for ti in tis:
        if ti.task_id not in expected_states:
            pytest.fail("unexpected task instance {} in state {}".format(
                ti.task_id, ti.state))
        assert ti.state == expected_states[ti.task_id]
    assert len(tis) == 2
def setUpClass(cls):
    """Run the parent class set-up, then empty the Pool table."""
    super(TestPoolApiExperimental, cls).setUpClass()
    db_session = Session()
    pools = db_session.query(Pool)
    pools.delete()
    db_session.commit()
    db_session.close()
def set_state(task, execution_date, upstream=False, downstream=False,
              future=False, past=False, state=State.SUCCESS, commit=False):
    """
    Set the state of a task instance and, if requested, of its relatives.

    Can set state for future tasks (calculated from execution_date) and
    retroactively for past tasks. The affected dag runs are integrity-checked
    so that task instances missing from them are created. Dag runs missing
    from the schedule are not created, except for subdag dag runs, which are
    created when needed.

    :param task: the task from which to work; ``task.dag`` must be set
    :param execution_date: timezone-aware execution date from which to start looking
    :param upstream: also mark all upstream tasks (parents)
    :param downstream: also mark all downstream tasks, including SubDags
    :param future: also mark tasks up to the DAG's latest execution date
    :param past: also mark tasks retroactively, starting from
        ``dag.default_args['start_date']`` (or ``dag.start_date``)
    :param state: state to which the task instances are set
    :param commit: write the new states to the database. When False,
        task-instance states are left untouched; the dag-run integrity checks
        and subdag dag-run creation still run.
    :return: list of TaskInstances whose state differed from ``state``. These
        are the instances that were altered (commit=True) or would be
        (commit=False).
    :raises AssertionError: if execution_date is naive, ``task.dag`` is None,
        or the DAG has no latest execution date
    """
    assert timezone.is_localized(execution_date)

    # microseconds are supported by the database, but are not handled
    # correctly by airflow on e.g. the filesystem and in other places
    execution_date = execution_date.replace(microsecond=0)

    assert task.dag is not None
    dag = task.dag

    latest_execution_date = dag.latest_execution_date
    assert latest_execution_date is not None

    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date

    if 'start_date' in dag.default_args:
        start_date = dag.default_args['start_date']
    elif dag.start_date:
        start_date = dag.start_date
    else:
        start_date = execution_date

    start_date = execution_date if not past else start_date

    if dag.schedule_interval == '@once':
        dates = [start_date]
    else:
        dates = dag.date_range(start_date=start_date, end_date=end_date)

    # find relatives (downstream and/or upstream tasks) if needed
    task_ids = [task.task_id]
    if downstream:
        relatives = task.get_flat_relatives(upstream=False)
        task_ids += [t.task_id for t in relatives]
    if upstream:
        relatives = task.get_flat_relatives(upstream=True)
        task_ids += [t.task_id for t in relatives]

    # verify the integrity of the dag runs in case a task was added or removed;
    # record the confirmed execution dates as they might differ from the
    # dates computed above
    confirmed_dates = []
    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
    for dr in drs:
        dr.dag = dag
        dr.verify_integrity()
        confirmed_dates.append(dr.execution_date)

    session = Session()
    try:
        # go through subdagoperators and create dag runs. We will only work
        # within the scope of the subdag. We won't propagate to the parent dag,
        # but we will propagate from parent to subdag.
        dags = [dag]
        sub_dag_ids = []
        while dags:
            current_dag = dags.pop()
            for task_id in task_ids:
                if not current_dag.has_task(task_id):
                    continue

                current_task = current_dag.get_task(task_id)
                if isinstance(current_task, SubDagOperator):
                    # this works as a kind of integrity check
                    # it creates missing dag runs for subdagoperators,
                    # maybe this should be moved to dagrun.verify_integrity
                    drs = _create_dagruns(current_task.subdag,
                                          execution_dates=confirmed_dates,
                                          state=State.RUNNING,
                                          run_id_template=BackfillJob.ID_FORMAT_PREFIX)

                    for dr in drs:
                        dr.dag = current_task.subdag
                        dr.verify_integrity()
                        if commit:
                            dr.state = state
                            session.merge(dr)

                    dags.append(current_task.subdag)
                    sub_dag_ids.append(current_task.subdag.dag_id)

        # now look for the task instances that are affected
        TI = TaskInstance

        # get all tasks of the main dag that will be affected by a state change
        qry_dag = session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.execution_date.in_(confirmed_dates),
            TI.task_id.in_(task_ids)).filter(
            or_(TI.state.is_(None), TI.state != state)
        )

        # get *all* tasks of the sub dags
        if sub_dag_ids:
            qry_sub_dag = session.query(TI).filter(
                TI.dag_id.in_(sub_dag_ids),
                TI.execution_date.in_(confirmed_dates)).filter(
                or_(TI.state.is_(None), TI.state != state)
            )

        if commit:
            tis_altered = qry_dag.with_for_update().all()
            if sub_dag_ids:
                tis_altered += qry_sub_dag.with_for_update().all()
            for ti in tis_altered:
                ti.state = state
            session.commit()
        else:
            tis_altered = qry_dag.all()
            if sub_dag_ids:
                tis_altered += qry_sub_dag.all()

            # detach the un-modified instances before handing them back
            session.expunge_all()
    finally:
        # previously the session leaked whenever any step above raised
        session.close()

    return tis_altered
def tearDownClass(cls):
    # Remove every User row left over from the known-event tests, then hand
    # control back to the base class teardown.
    db_session = Session()
    user_rows = db_session.query(models.User)
    user_rows.delete()
    db_session.commit()
    db_session.close()
    super(TestKnownEventView, cls).tearDownClass()
def start(self):
    """Build the Mesos framework description and start the scheduler driver.

    Settings are read from the ``[mesos]`` configuration section:

    * ``MASTER`` (required): Mesos master URL.
    * ``TASK_CPU`` / ``TASK_MEMORY``: per-task resources. 1 and 256 are
      used when the option is empty.
    * ``CHECKPOINT``: enables framework checkpointing. When it is set and
      ``FAILOVER_TIMEOUT`` is non-empty, the failover timeout is applied.
      If a connection named ``FRAMEWORK_CONNID_PREFIX + <framework name>``
      exists, its ``extra`` value is used as the framework id.
    * ``AUTHENTICATE``: when true, ``DEFAULT_PRINCIPAL`` and
      ``DEFAULT_SECRET`` are required and passed to the driver as a
      credential.

    Creates ``self.task_queue`` / ``self.result_queue`` and stores the
    started driver on ``self.mesos_driver``.

    :raises AirflowException: if the master URL is missing, or if
        authentication is enabled without a principal or secret.
    """
    self.task_queue = Queue()
    self.result_queue = Queue()
    framework = mesos_pb2.FrameworkInfo()
    framework.user = ''

    if not configuration.conf.get('mesos', 'MASTER'):
        self.log.error("Expecting mesos master URL for mesos executor")
        raise AirflowException("mesos.master not provided for mesos executor")
    master = configuration.conf.get('mesos', 'MASTER')

    framework.name = get_framework_name()

    if not configuration.conf.get('mesos', 'TASK_CPU'):
        task_cpu = 1
    else:
        task_cpu = configuration.conf.getint('mesos', 'TASK_CPU')

    if not configuration.conf.get('mesos', 'TASK_MEMORY'):
        task_memory = 256
    else:
        task_memory = configuration.conf.getint('mesos', 'TASK_MEMORY')

    if configuration.conf.getboolean('mesos', 'CHECKPOINT'):
        framework.checkpoint = True

        if configuration.conf.get('mesos', 'FAILOVER_TIMEOUT'):
            # Import here to work around a circular import error
            from airflow.models import Connection

            # Query the database to get the ID of the Mesos Framework, if available.
            conn_id = FRAMEWORK_CONNID_PREFIX + framework.name
            session = Session()
            try:
                connection = session.query(Connection).filter_by(conn_id=conn_id).first()
                if connection is not None:
                    # Set the Framework ID to let the scheduler reconnect with running tasks.
                    framework.id.value = connection.extra
            finally:
                # Release the session once the framework id has been read.
                session.close()

            framework.failover_timeout = configuration.conf.getint(
                'mesos', 'FAILOVER_TIMEOUT'
            )
    else:
        framework.checkpoint = False

    self.log.info(
        'MesosFramework master : %s, name : %s, cpu : %s, mem : %s, checkpoint : %s',
        master, framework.name, task_cpu, task_memory, framework.checkpoint
    )

    implicit_acknowledgements = 1
    # Driver arguments that follow the scheduler instance; a credential is
    # appended only in authenticated mode.
    driver_args = [framework, master, implicit_acknowledgements]

    if configuration.conf.getboolean('mesos', 'AUTHENTICATE'):
        if not configuration.conf.get('mesos', 'DEFAULT_PRINCIPAL'):
            self.log.error("Expecting authentication principal in the environment")
            raise AirflowException("mesos.default_principal not provided in authenticated mode")
        if not configuration.conf.get('mesos', 'DEFAULT_SECRET'):
            self.log.error("Expecting authentication secret in the environment")
            raise AirflowException("mesos.default_secret not provided in authenticated mode")

        credential = mesos_pb2.Credential()
        credential.principal = configuration.conf.get('mesos', 'DEFAULT_PRINCIPAL')
        credential.secret = configuration.conf.get('mesos', 'DEFAULT_SECRET')

        framework.principal = credential.principal
        driver_args.append(credential)
    else:
        framework.principal = 'Airflow'

    self.mesos_driver = mesos.native.MesosSchedulerDriver(
        AirflowMesosScheduler(self.task_queue, self.result_queue, task_cpu, task_memory),
        *driver_args
    )
    self.mesos_driver.start()
class TestPoolModelView(unittest.TestCase):
    """Admin UI tests for creating pools through the Pool model view."""

    CREATE_ENDPOINT = '/admin/pool/new/?url=/admin/pool/'

    @classmethod
    def setUpClass(cls):
        super(TestPoolModelView, cls).setUpClass()
        db_session = Session()
        db_session.query(models.Pool).delete()
        db_session.commit()
        db_session.close()

    def setUp(self):
        super(TestPoolModelView, self).setUp()
        configuration.load_test_config()
        flask_app = application.create_app(testing=True)
        # Let the test client post forms without CSRF tokens.
        flask_app.config['WTF_CSRF_METHODS'] = []
        self.app = flask_app.test_client()
        self.session = Session()
        self.pool = dict(
            pool='test-pool',
            slots=777,
            description='test-pool-description',
        )

    def tearDown(self):
        self.session.query(models.Pool).delete()
        self.session.commit()
        self.session.close()
        super(TestPoolModelView, self).tearDown()

    def _post_pool(self):
        """Submit ``self.pool`` to the create form, following redirects."""
        return self.app.post(
            self.CREATE_ENDPOINT,
            data=self.pool,
            follow_redirects=True,
        )

    def _pool_count(self):
        """Number of Pool rows currently stored."""
        return self.session.query(models.Pool).count()

    def test_create_pool(self):
        resp = self._post_pool()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._pool_count(), 1)

    def test_create_pool_with_same_name(self):
        # First submission creates the pool ...
        self._post_pool()
        # ... the second, with an identical name, must be rejected.
        resp = self._post_pool()
        self.assertIn('Already exists.', resp.data.decode('utf-8'))
        self.assertEqual(self._pool_count(), 1)

    def test_create_pool_with_empty_name(self):
        self.pool['pool'] = ''
        resp = self._post_pool()
        self.assertIn('This field is required.', resp.data.decode('utf-8'))
        self.assertEqual(self._pool_count(), 0)
class TestLogView(unittest.TestCase):
    """Checks that the admin log view renders a task's local log file."""

    DAG_ID = 'dag_for_testing_log_view'
    TASK_ID = 'task_for_testing_log_view'
    DEFAULT_DATE = datetime(2017, 9, 1)
    ENDPOINT = '/admin/airflow/log?dag_id={dag_id}&task_id={task_id}&execution_date={execution_date}'.format(
        dag_id=DAG_ID,
        task_id=TASK_ID,
        execution_date=DEFAULT_DATE,
    )

    @classmethod
    def _test_task_instances(cls, session):
        """Return a query matching only the task instance this test creates.

        The three criteria are passed to ``filter`` as separate arguments so
        SQLAlchemy combines them with SQL AND. Joining them with Python's
        ``and`` would evaluate the truthiness of each comparison and hand
        only one of them to ``filter``.
        """
        return session.query(TaskInstance).filter(
            TaskInstance.dag_id == cls.DAG_ID,
            TaskInstance.task_id == cls.TASK_ID,
            TaskInstance.execution_date == cls.DEFAULT_DATE,
        )

    @classmethod
    def setUpClass(cls):
        super(TestLogView, cls).setUpClass()
        session = Session()
        cls._test_task_instances(session).delete()
        session.commit()
        session.close()

    def setUp(self):
        super(TestLogView, self).setUp()
        configuration.load_test_config()
        # Point the file.task handler at the test_logs directory next to this file.
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['file.task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging.config.dictConfig(logging_config)

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()

        # Register the DAG in the dagbag used by the views module, and
        # persist a task instance for it at DEFAULT_DATE.
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()

    def tearDown(self):
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        self._test_task_instances(self.session).delete()
        self.session.commit()
        self.session.close()
        super(TestLogView, self).tearDown()

    def test_get_file_task_log(self):
        response = self.app.get(
            TestLogView.ENDPOINT,
            follow_redirects=True,
        )
        self.assertEqual(response.status_code, 200)
        self.assertIn(
            '<pre id="attempt-1">*** Reading local log.\nLog for testing.\n</pre>',
            response.data.decode('utf-8'))
def setUpClass(cls):
    # Begin the pool tests with an empty Pool table.
    super(TestPoolModelView, cls).setUpClass()
    db_session = Session()
    pools = db_session.query(models.Pool)
    pools.delete()
    db_session.commit()
    db_session.close()
class TestMarkDAGRun(unittest.TestCase):
    """Tests for set_dag_run_state_to_success / _failed / _running.

    Each transition test creates a dag run of ``example_bash_operator`` in a
    given initial state and puts its task instances into a fixed mix of
    states (see ``_set_default_task_instance_states``). It then applies one
    helper and checks which task instances changed, plus the dag run's new
    state and start/end dates.
    """

    def setUp(self):
        self.dagbag = models.DagBag(include_examples=True)
        self.dag1 = self.dagbag.dags['example_bash_operator']
        self.dag2 = self.dagbag.dags['example_subdag_operator']

        self.execution_dates = [days_ago(2), days_ago(1), days_ago(0)]

        self.session = Session()

    def _set_default_task_instance_states(self, dr):
        """Give each task of an example_bash_operator run a distinct state.

        Runs of any other DAG are left untouched.
        """
        if dr.dag_id != 'example_bash_operator':
            return
        # success task
        dr.get_task_instance('runme_0').set_state(State.SUCCESS, self.session)
        # skipped task
        dr.get_task_instance('runme_1').set_state(State.SKIPPED, self.session)
        # retry task
        dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY, self.session)
        # queued task
        dr.get_task_instance('also_run_this').set_state(State.QUEUED, self.session)
        # running task
        dr.get_task_instance('run_after_loop').set_state(State.RUNNING, self.session)
        # failed task
        dr.get_task_instance('run_this_last').set_state(State.FAILED, self.session)

    def _verify_task_instance_states_remain_default(self, dr):
        """Assert the states set by _set_default_task_instance_states are unchanged."""
        self.assertEqual(dr.get_task_instance('runme_0').state, State.SUCCESS)
        self.assertEqual(dr.get_task_instance('runme_1').state, State.SKIPPED)
        self.assertEqual(dr.get_task_instance('runme_2').state, State.UP_FOR_RETRY)
        self.assertEqual(dr.get_task_instance('also_run_this').state, State.QUEUED)
        self.assertEqual(dr.get_task_instance('run_after_loop').state, State.RUNNING)
        self.assertEqual(dr.get_task_instance('run_this_last').state, State.FAILED)

    def _verify_task_instance_states(self, dag, date, state):
        """Assert every task instance of ``dag`` at ``date`` is in ``state``."""
        TI = models.TaskInstance
        tis = self.session.query(TI).filter(TI.dag_id == dag.dag_id,
                                            TI.execution_date == date).all()
        # Guard against the check passing vacuously when nothing matches.
        self.assertTrue(tis)
        for ti in tis:
            self.assertEqual(ti.state, state)

    def _create_test_dag_run(self, state, date):
        """Create a manual dag run of dag1 in ``state`` at ``date``."""
        return self.dag1.create_dagrun(
            run_id='manual__' + datetime.now().isoformat(),
            state=state,
            execution_date=date,
            session=self.session
        )

    def _verify_dag_run_state(self, dag, date, state):
        """Assert the first dag run of ``dag`` at ``date`` is in ``state``."""
        drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
        dr = drs[0]
        self.assertEqual(dr.get_state(), state)

    def _verify_dag_run_dates(self, dag, date, state, middle_time):
        """Check start/end dates relative to ``middle_time``.

        ``middle_time`` is taken after the dag run is created. When the
        target state is RUNNING, start_date should have been reset (later
        than middle_time) and end_date cleared. Otherwise start_date is left
        alone and end_date is set (later than middle_time).
        """
        drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
        dr = drs[0]
        if state == State.RUNNING:
            self.assertGreater(dr.start_date, middle_time)
            self.assertIsNone(dr.end_date)
        else:
            self.assertLess(dr.start_date, middle_time)
            self.assertGreater(dr.end_date, middle_time)

    def test_set_running_dag_run_to_success(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.RUNNING, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_success(self.dag1, date, commit=True)

        # All except the SUCCESS task should be altered.
        self.assertEqual(len(altered), 5)
        self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
        self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
        self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)

    def test_set_running_dag_run_to_failed(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.RUNNING, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_failed(self.dag1, date, commit=True)

        # Only running task should be altered.
        self.assertEqual(len(altered), 1)
        self._verify_dag_run_state(self.dag1, date, State.FAILED)
        self.assertEqual(dr.get_task_instance('run_after_loop').state, State.FAILED)
        self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)

    def test_set_running_dag_run_to_running(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.RUNNING, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_running(self.dag1, date, commit=True)

        # None of the tasks should be altered.
        self.assertEqual(len(altered), 0)
        self._verify_dag_run_state(self.dag1, date, State.RUNNING)
        self._verify_task_instance_states_remain_default(dr)
        self._verify_dag_run_dates(self.dag1, date, State.RUNNING, middle_time)

    def test_set_success_dag_run_to_success(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.SUCCESS, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_success(self.dag1, date, commit=True)

        # All except the SUCCESS task should be altered.
        self.assertEqual(len(altered), 5)
        self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
        self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
        self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)

    def test_set_success_dag_run_to_failed(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.SUCCESS, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_failed(self.dag1, date, commit=True)

        # Only running task should be altered.
        self.assertEqual(len(altered), 1)
        self._verify_dag_run_state(self.dag1, date, State.FAILED)
        self.assertEqual(dr.get_task_instance('run_after_loop').state, State.FAILED)
        self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)

    def test_set_success_dag_run_to_running(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.SUCCESS, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_running(self.dag1, date, commit=True)

        # None of the tasks should be altered.
        self.assertEqual(len(altered), 0)
        self._verify_dag_run_state(self.dag1, date, State.RUNNING)
        self._verify_task_instance_states_remain_default(dr)
        self._verify_dag_run_dates(self.dag1, date, State.RUNNING, middle_time)

    def test_set_failed_dag_run_to_success(self):
        date = self.execution_dates[0]
        # The initial state must be FAILED, or this duplicates the SUCCESS test.
        dr = self._create_test_dag_run(State.FAILED, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_success(self.dag1, date, commit=True)

        # All except the SUCCESS task should be altered.
        self.assertEqual(len(altered), 5)
        self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
        self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
        self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)

    def test_set_failed_dag_run_to_failed(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.FAILED, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_failed(self.dag1, date, commit=True)

        # Only running task should be altered.
        self.assertEqual(len(altered), 1)
        self._verify_dag_run_state(self.dag1, date, State.FAILED)
        self.assertEqual(dr.get_task_instance('run_after_loop').state, State.FAILED)
        self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)

    def test_set_failed_dag_run_to_running(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.FAILED, date)
        middle_time = timezone.utcnow()
        self._set_default_task_instance_states(dr)

        altered = set_dag_run_state_to_running(self.dag1, date, commit=True)

        # None of the tasks should be altered.
        self.assertEqual(len(altered), 0)
        self._verify_dag_run_state(self.dag1, date, State.RUNNING)
        self._verify_task_instance_states_remain_default(dr)
        self._verify_dag_run_dates(self.dag1, date, State.RUNNING, middle_time)

    def test_set_state_without_commit(self):
        date = self.execution_dates[0]
        dr = self._create_test_dag_run(State.RUNNING, date)
        self._set_default_task_instance_states(dr)

        will_be_altered = set_dag_run_state_to_running(self.dag1, date, commit=False)

        # None of the tasks will be altered.
        self.assertEqual(len(will_be_altered), 0)
        self._verify_dag_run_state(self.dag1, date, State.RUNNING)
        self._verify_task_instance_states_remain_default(dr)

        will_be_altered = set_dag_run_state_to_failed(self.dag1, date, commit=False)

        # Only the running task will be altered.
        self.assertEqual(len(will_be_altered), 1)
        self._verify_dag_run_state(self.dag1, date, State.RUNNING)
        self._verify_task_instance_states_remain_default(dr)

        will_be_altered = set_dag_run_state_to_success(self.dag1, date, commit=False)

        # All except the SUCCESS task should be altered.
        self.assertEqual(len(will_be_altered), 5)
        self._verify_dag_run_state(self.dag1, date, State.RUNNING)
        self._verify_task_instance_states_remain_default(dr)

    def test_set_state_with_multiple_dagruns(self):
        self.dag2.create_dagrun(
            run_id='manual__' + datetime.now().isoformat(),
            state=State.FAILED,
            execution_date=self.execution_dates[0],
            session=self.session
        )
        self.dag2.create_dagrun(
            run_id='manual__' + datetime.now().isoformat(),
            state=State.FAILED,
            execution_date=self.execution_dates[1],
            session=self.session
        )
        self.dag2.create_dagrun(
            run_id='manual__' + datetime.now().isoformat(),
            state=State.RUNNING,
            execution_date=self.execution_dates[2],
            session=self.session
        )

        altered = set_dag_run_state_to_success(self.dag2, self.execution_dates[1],
                                               commit=True)

        # Recursively count number of tasks in the dag
        def count_dag_tasks(dag):
            count = len(dag.tasks)
            subdag_counts = [count_dag_tasks(subdag) for subdag in dag.subdags]
            count += sum(subdag_counts)
            return count

        self.assertEqual(len(altered), count_dag_tasks(self.dag2))
        self._verify_dag_run_state(self.dag2, self.execution_dates[1], State.SUCCESS)

        # Make sure the states of the other dag runs are not changed.
        self._verify_dag_run_state(self.dag2, self.execution_dates[0], State.FAILED)
        self._verify_dag_run_state(self.dag2, self.execution_dates[2], State.RUNNING)

    def test_set_dag_run_state_edge_cases(self):
        # Dag does not exist
        altered = set_dag_run_state_to_success(None, self.execution_dates[0])
        self.assertEqual(len(altered), 0)
        altered = set_dag_run_state_to_failed(None, self.execution_dates[0])
        self.assertEqual(len(altered), 0)
        altered = set_dag_run_state_to_running(None, self.execution_dates[0])
        self.assertEqual(len(altered), 0)

        # Invalid execution date
        altered = set_dag_run_state_to_success(self.dag1, None)
        self.assertEqual(len(altered), 0)
        altered = set_dag_run_state_to_failed(self.dag1, None)
        self.assertEqual(len(altered), 0)
        altered = set_dag_run_state_to_running(self.dag1, None)
        self.assertEqual(len(altered), 0)

        # No DagRun exists, so dag.latest_execution_date is None and an
        # AssertionError is expected from the state-setting code. The first
        # call additionally passes a timezone-naive execution date.
        self.assertRaises(AssertionError, set_dag_run_state_to_success, self.dag1,
                          timezone.make_naive(self.execution_dates[0]))
        self.assertRaises(AssertionError, set_dag_run_state_to_success, self.dag1,
                          self.execution_dates[0])

    def tearDown(self):
        self.dag1.clear()
        self.dag2.clear()

        self.session.query(models.DagRun).delete()
        self.session.query(models.TaskInstance).delete()
        self.session.query(models.DagStat).delete()
        self.session.commit()
        self.session.close()
def reset(dag_id=TEST_DAG_ID):
    """Delete every TaskInstance row belonging to ``dag_id`` and commit."""
    db_session = Session()
    matching = db_session.query(models.TaskInstance).filter_by(dag_id=dag_id)
    matching.delete()
    db_session.commit()
    db_session.close()
def roles(self):
    """Return a query for the Role rows whose name equals ``self.role``."""
    # NOTE(review): ``Session`` is used without being instantiated, which
    # presumably relies on it being a scoped_session registry; verify.
    role_query = Session.query(Role)
    return role_query.filter_by(name=self.role)
for parent_graph in graphs: if parent_id == parent_graph.sql_sensor_id: graph.set_upstream(parent_graph) return subdag def execute_statement(**kw): hook = BaseHook.get_connection(kw['conn_id']).get_hook() logging.info('Executing: %s', kw['sql']) hook.run(kw['sql']) session = Session() dags = session.query(SQLDag) \ .filter(SQLSensor.enabled == True) \ .all() for _dag in dags: default_args['start_date'] = datetime.combine( _dag.start_date, dt.time(), ) dag = DAG(_dag.id, default_args=default_args, schedule_interval=_dag.schedule_interval, ) subdag_name = 'subdag_with_sensors' sub_dag = SubDagOperator( subdag=make_subdag(subdag_id=subdag_name,
def setUpClass(cls):
    # Clear all DagRun rows before the dag-runs endpoint tests start.
    super(TestDagRunsEndpoint, cls).setUpClass()
    db_session = Session()
    dag_runs = db_session.query(DagRun)
    dag_runs.delete()
    db_session.commit()
    db_session.close()
class TestVariableView(unittest.TestCase):
    """Admin UI tests for the Variable views."""

    CREATE_ENDPOINT = '/admin/variable/new/?url=/admin/variable/'

    @classmethod
    def setUpClass(cls):
        super(TestVariableView, cls).setUpClass()
        db_session = Session()
        db_session.query(models.Variable).delete()
        db_session.commit()
        db_session.close()

    def setUp(self):
        super(TestVariableView, self).setUp()
        configuration.load_test_config()
        flask_app = application.create_app(testing=True)
        # Let the test client post forms without CSRF tokens.
        flask_app.config['WTF_CSRF_METHODS'] = []
        self.app = flask_app.test_client()
        self.session = Session()
        self.variable = dict(
            key='test_key',
            val='text_val',
            is_encrypted=True,
        )

    def tearDown(self):
        self.session.query(models.Variable).delete()
        self.session.commit()
        self.session.close()
        super(TestVariableView, self).tearDown()

    def test_can_handle_error_on_decrypt(self):
        # Create a valid variable flagged as encrypted through the admin form.
        create_resp = self.app.post(
            self.CREATE_ENDPOINT,
            data=self.variable,
            follow_redirects=True,
        )
        self.assertEqual(create_resp.status_code, 200)

        # Overwrite the stored value with plain text while it stays flagged
        # as encrypted.
        Var = models.Variable
        target_rows = self.session.query(Var).filter(Var.key == self.variable['key'])
        target_rows.update({'val': 'failed_value_not_encrypted'},
                           synchronize_session=False)
        self.session.commit()

        # The Variables list page must still render successfully.
        list_resp = self.app.get('/admin/variable', follow_redirects=True)
        self.assertEqual(list_resp.status_code, 200)
        self.assertEqual(self.session.query(models.Variable).count(), 1)

    def test_xss_prevention(self):
        xss = "/admin/airflow/variables/asdf<img%20src=''%20onerror='alert(1);'>"

        resp = self.app.get(xss, follow_redirects=True)

        self.assertEqual(resp.status_code, 404)
        self.assertNotIn("<img src='' onerror='alert(1);'>",
                         resp.data.decode("utf-8"))
class TestConnectionModelView(unittest.TestCase):
    """Admin UI tests for creating Connection records, including extras."""

    CREATE_ENDPOINT = '/admin/connection/new/?url=/admin/connection/'
    CONN_ID = "new_conn"
    CONN = {
        "conn_id": CONN_ID,
        "conn_type": "http",
        "host": "https://example.com",
    }

    @classmethod
    def setUpClass(cls):
        super(TestConnectionModelView, cls).setUpClass()
        flask_app = application.create_app(testing=True)
        # Let the test client post forms without CSRF tokens.
        flask_app.config['WTF_CSRF_METHODS'] = []
        cls.app = flask_app.test_client()

    def setUp(self):
        self.session = Session()

    def tearDown(self):
        self._test_conn_query().delete()
        self.session.commit()
        self.session.close()
        super(TestConnectionModelView, self).tearDown()

    def _test_conn_query(self):
        """Query for the connection row(s) whose conn_id is CONN_ID."""
        conn_model = models.Connection
        return self.session.query(conn_model).filter(conn_model.conn_id == self.CONN_ID)

    def _post(self, form):
        """Submit ``form`` to the create endpoint, following redirects."""
        return self.app.post(self.CREATE_ENDPOINT, data=form, follow_redirects=True)

    def test_create(self):
        resp = self._post(self.CONN)
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._test_conn_query().count(), 1)

    def test_create_error(self):
        resp = self._post({"conn_type": "http"})
        self.assertEqual(resp.status_code, 200)
        self.assertIn(b'has-error', resp.data)
        self.assertEqual(self._test_conn_query().count(), 0)

    def test_create_extras(self):
        form = dict(
            self.CONN,
            conn_type="google_cloud_platform",
            extra__google_cloud_platform__num_retries="2",
        )
        resp = self._post(form)
        self.assertEqual(resp.status_code, 200)
        stored = self._test_conn_query().one()
        self.assertEqual(
            stored.extra_dejson['extra__google_cloud_platform__num_retries'], 2)

    def test_create_extras_empty_field(self):
        form = dict(
            self.CONN,
            conn_type="google_cloud_platform",
            extra__google_cloud_platform__num_retries="",
        )
        resp = self._post(form)
        self.assertEqual(resp.status_code, 200)
        stored = self._test_conn_query().one()
        self.assertIsNone(
            stored.extra_dejson['extra__google_cloud_platform__num_retries'])
def setUpClass(cls):
    # Begin the variable-view tests with an empty Variable table.
    super(TestVariableView, cls).setUpClass()
    db_session = Session()
    variables = db_session.query(models.Variable)
    variables.delete()
    db_session.commit()
    db_session.close()
def execute(self, context=None):  # noqa: C901
    """Set database permissions on objects for roles via schema-tools.

    Calls schema-tools' ``apply_schema_and_profile_permissions`` in one of
    two modes:

    * Batch (``self.batch_ind`` truthy): find DAG runs with state
      ``"success"`` whose ``end_date`` falls within the last
      ``self.batch_timewindow`` (``"HH:MM"``). Apply permissions for the
      dataset each of those DAGs maps to. Each dataset is processed once per
      call, even if several runs or DAGs map to it.
    * Single (otherwise, when ``self.dataset_name`` is set): apply
      permissions for that one dataset. This can be used as a final step
      within a DAG run.

    A dataset whose schema cannot be fetched (``HTTPError``) is logged and
    skipped.
    """
    # setup logger so output can be added to the Airflow logs
    logger = logging.getLogger(__name__)

    # setup database connection where the database objects are present
    engine = _get_engine(self.db_conn)

    def _grant(dataset_name):
        # Fetch the dataset schema and apply its permissions to the database.
        ams_schema = schema_defs_from_url(
            schemas_url=self.schema_url,
            dataset_name=dataset_name,
        )
        apply_schema_and_profile_permissions(
            engine=engine,
            pg_schema=self.db_schema,
            ams_schema=ams_schema,
            profiles=self.profiles,
            role=self.role,
            scope=self.scope,
            dry_run=self.dry_run,
            create_roles=self.create_roles,
            revoke=self.revoke,
        )

    def _dataset_for(dag_id):
        # DAG_DATASET maps (substrings of) dag_ids to dataset names for DAGs
        # whose id differs from the dataset name; the first matching key wins.
        return next(
            (name for key, name in DAG_DATASET.items() if key in dag_id),
            dag_id,
        )

    # Option ONE: batch grant
    if self.batch_ind:
        # get current datetime and make it aware of the TZ (mandatory)
        now = pendulum.now("Europe/Amsterdam")

        # start of the window = now minus the configured "HH:MM" duration
        window_parts = self.batch_timewindow.split(":")
        window_start = now - timedelta(
            hours=int(window_parts[0]),
            minutes=int(window_parts[1]),
        )
        logger.info("the time window is set starting at: %s till now", window_start)

        # Read the qualifying dag_ids from the Airflow metadata database and
        # release the session before the (potentially slow) grants run.
        session = Session()
        try:
            finished_dag_ids = [
                row.dag_id
                for row in session.query(DagRun.dag_id)
                .filter(DagRun.end_date > window_start)
                .filter(DagRun._state == "success")
                # exclude the dag itself that calls this batch grant method
                .filter(DagRun.dag_id != "airflow_db_permissions")
                # exclude the update_dag, it does not contain DB objects to grant
                .filter(DagRun.dag_id != "update_dags")
            ]
        finally:
            session.close()

        if not finished_dag_ids:
            logger.error(
                "Nothing to grant, no finished dags detected within time window of %s",
                self.batch_timewindow,
            )
            return

        # De-duplicate (keeping first-seen order): a DAG may have run several
        # times in the window, and several DAGs may map to the same dataset.
        dataset_names = dict.fromkeys(_dataset_for(dag_id) for dag_id in finished_dag_ids)
        for dataset_name in dataset_names:
            logger.info("set grants for %s", dataset_name)
            try:
                _grant(dataset_name)
            except HTTPError:
                logger.error("Could not get data schema for %s", dataset_name)

    # Option TWO: grant on single dataset (can be used as a final step within a dag run)
    elif self.dataset_name:
        try:
            _grant(self.dataset_name)
        except HTTPError:
            logger.error("Could not get data schema for %s", self.dataset_name)
class TestKnownEventView(unittest.TestCase):
    """Admin UI tests for creating KnownEvent records."""

    CREATE_ENDPOINT = '/admin/knownevent/new/?url=/admin/knownevent/'

    @classmethod
    def setUpClass(cls):
        super(TestKnownEventView, cls).setUpClass()
        db_session = Session()
        db_session.query(models.KnownEvent).delete()
        db_session.query(models.User).delete()
        db_session.commit()
        reporter = models.User(username='******')
        db_session.add(reporter)
        db_session.commit()
        # Keep the id so each test can use it as the event's reporter.
        cls.user_id = reporter.id
        db_session.close()

    def setUp(self):
        super(TestKnownEventView, self).setUp()
        configuration.load_test_config()
        flask_app = application.create_app(testing=True)
        # Let the test client post forms without CSRF tokens.
        flask_app.config['WTF_CSRF_METHODS'] = []
        self.app = flask_app.test_client()
        self.session = Session()
        self.known_event = dict(
            label='event-label',
            event_type='1',
            start_date='2017-06-05 12:00:00',
            end_date='2017-06-05 13:00:00',
            reported_by=self.user_id,
            description='',
        )

    def tearDown(self):
        self.session.query(models.KnownEvent).delete()
        self.session.commit()
        self.session.close()
        super(TestKnownEventView, self).tearDown()

    @classmethod
    def tearDownClass(cls):
        db_session = Session()
        db_session.query(models.User).delete()
        db_session.commit()
        db_session.close()
        super(TestKnownEventView, cls).tearDownClass()

    def _post_known_event(self):
        """Submit ``self.known_event`` to the create form, following redirects."""
        return self.app.post(
            self.CREATE_ENDPOINT,
            data=self.known_event,
            follow_redirects=True,
        )

    def _known_event_count(self):
        """Number of KnownEvent rows currently stored."""
        return self.session.query(models.KnownEvent).count()

    def test_create_known_event(self):
        resp = self._post_known_event()
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(self._known_event_count(), 1)

    def test_create_known_event_with_end_data_earlier_than_start_date(self):
        self.known_event['end_date'] = '2017-06-05 11:00:00'
        resp = self._post_known_event()
        self.assertIn(
            'Field must be greater than or equal to Start Date.',
            resp.data.decode('utf-8'),
        )
        self.assertEqual(self._known_event_count(), 0)
def tearDown(self):
    """Remove the TaskInstance and TaskFail rows written for TEST_DAG_ID."""
    db_session = Session()
    for model in (models.TaskInstance, TaskFail):
        db_session.query(model).filter_by(dag_id=TEST_DAG_ID).delete()
    db_session.commit()
    db_session.close()
def tearDown(self):
    # Wipe every DagRun row after each dag-runs endpoint test.
    db_session = Session()
    dag_runs = db_session.query(DagRun)
    dag_runs.delete()
    db_session.commit()
    db_session.close()
    super(TestDagRunsEndpoint, self).tearDown()
def tearDownClass(cls):
    # Drop the User rows created for the chart tests, then run the base
    # class teardown.
    db_session = Session()
    user_rows = db_session.query(models.User)
    user_rows.delete()
    db_session.commit()
    db_session.close()
    super(TestChartModelView, cls).tearDownClass()
def reset_dr_db(dag_id):
    """Delete every DagRun row belonging to ``dag_id`` and commit."""
    db_session = Session()
    db_session.query(models.DagRun).filter_by(dag_id=dag_id).delete()
    db_session.commit()
    db_session.close()