def test_write(self):
    """
    Test that rendered fields are written to the DB, and that writing again
    for the same dag_id/task_id/execution_date stores the newly rendered values.
    """
    Variable.set(key="test_key", value="test_val")

    session = settings.Session()
    # Nothing has been written yet, so the table must start out empty.
    assert [] == session.query(RTIF).all()

    def fetch_stored_row(record):
        # Read back (dag_id, task_id, rendered_fields) for the record's key.
        return (
            session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            .filter(
                RTIF.dag_id == record.dag_id,
                RTIF.task_id == record.task_id,
                RTIF.execution_date == record.execution_date,
            )
            .first()
        )

    with DAG("test_write", start_date=START_DATE):
        task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

    rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
    rtif.write()

    first_row = fetch_stored_row(rtif)
    assert ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}) == first_row

    # Change the Variable so that re-rendering produces a different value,
    # then write again for the same key to check the overwrite.
    Variable.delete("test_key")
    Variable.set(key="test_key", value="test_val_updated")

    with DAG("test_write", start_date=START_DATE):
        updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

    rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE))
    rtif_updated.write()

    second_row = fetch_stored_row(rtif_updated)
    assert (
        'test_write',
        'test',
        {'bash_command': 'echo test_val_updated', 'env': None},
    ) == second_row
def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
    """
    Test that RTIF.delete_old_records trims the rendered_task_instance_fields
    rows of one dag_id/task_id down to ``num_to_keep``, issuing exactly
    ``expected_query_count`` queries while doing so.
    """
    session = settings.Session()
    dag = DAG("test_delete_old_records", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

    # One record per consecutive day, starting at EXECUTION_DATE.
    rtif_list = [
        RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=offset)))
        for offset in range(rtif_num)
    ]
    session.add_all(rtif_list)
    session.commit()

    def rows_for_task():
        return session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

    stored = rows_for_task()
    for created in rtif_list:
        self.assertIn(created, stored)
    self.assertEqual(rtif_num, len(stored))

    # Verify old records are deleted and only 'num_to_keep' records are kept.
    # Only the delete call is inside the query-count scope.
    with assert_queries_count(expected_query_count):
        RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)

    self.assertEqual(remaining_rtifs, len(rows_for_task()))
def test_get_templated_fields(self, templated_field, expected_rendered_field):
    """
    Test that template_fields are rendered when an RTIF is built from a TI,
    that the RTIF can be stored in the DB, and that RTIF.get_templated_fields
    returns the stored values (or None for a TI that has no stored row).
    """
    dag = DAG("test_serialized_rendered_fields", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command=templated_field)

    ti = TI(task=task, execution_date=EXECUTION_DATE)
    rtif = RTIF(ti=ti)

    # The RTIF must carry the same identifying keys as the TI it came from.
    for key_attr in ("dag_id", "task_id", "execution_date"):
        self.assertEqual(getattr(ti, key_attr), getattr(rtif, key_attr))
    self.assertEqual(expected_rendered_field, rtif.rendered_fields.get("bash_command"))

    # Persist the record before reading it back through the class method.
    with create_session() as session:
        session.add(rtif)

    self.assertEqual(
        {"bash_command": expected_rendered_field, "env": None},
        RTIF.get_templated_fields(ti=ti),
    )

    # Else branch of get_templated_fields: a TI whose rendered fields were
    # never stored yields None.
    with dag:
        task_2 = BashOperator(task_id="test2", bash_command=templated_field)

    ti2 = TI(task_2, EXECUTION_DATE)
    self.assertIsNone(RTIF.get_templated_fields(ti=ti2))
def test_delete_old_records(self):
    """
    Test that RTIF.delete_old_records with num_to_keep=1 leaves only the row
    with the latest execution_date for a given dag_id/task_id.
    """
    session = settings.Session()
    dag = DAG("test_delete_old_records", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

    # Three records on consecutive days; rtif_3 is the most recent.
    rtif_1 = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
    rtif_2 = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=1)))
    rtif_3 = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=2)))
    session.add_all([rtif_1, rtif_2, rtif_3])
    session.commit()

    def query_task_rows():
        return session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

    stored = query_task_rows()
    for created in (rtif_1, rtif_2, rtif_3):
        self.assertIn(created, stored)
    self.assertEqual(3, len(stored))

    # Verify old records are deleted and only 1 record (the latest) is kept.
    RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=1)

    survivors = query_task_rows()
    self.assertEqual(1, len(survivors))
    self.assertEqual(rtif_3.execution_date, survivors[0].execution_date)
def test_rendered_view(self, get_dag_function):
    """
    Test that the Rendered View page displays the values stored in
    RenderedTaskInstanceFields for task1 at the default date.
    """
    serialized_dag = SerializedDagModel.get(self.dag.dag_id).dag
    get_dag_function.return_value = serialized_dag

    # Sanity check: the task's template is the raw, unrendered string.
    self.assertEqual(self.task1.bash_command, '{{ task_instance_key_str }}')

    ti = TaskInstance(self.task1, self.default_date)
    with create_session() as session:
        session.add(RTIF(ti))

    encoded_date = self.percent_encode(self.default_date)
    url_template = '/admin/airflow/rendered?task_id=task1&dag_id=testdag&execution_date={}'
    url = url_template.format(encoded_date)

    page_text = self.app.get(url, follow_redirects=True).data.decode('utf-8')
    self.assertIn("testdag__task1__20200301", page_text)
def test_redact(self, redact):
    """
    Test that RTIF.rendered_fields holds the values returned by the mocked
    ``redact`` for each templated field, not the raw rendered values.
    """
    dag = DAG("test_ritf_redact", start_date=START_DATE)
    with dag:
        task = BashOperator(
            task_id="test",
            bash_command="echo {{ var.value.api_key }}",
            env={'foo': 'secret', 'other_api_key': 'masked based on key name'},
        )

    # Successive calls to the mock return these values in order. The
    # assertion below expects the first for bash_command and the second for env.
    redact.side_effect = ['val 1', 'val 2']

    rtif = RTIF(ti=TI(task=task, execution_date=EXECUTION_DATE))

    expected_fields = {'bash_command': 'val 1', 'env': 'val 2'}
    assert rtif.rendered_fields == expected_fields
def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
    """
    Test that old records are deleted from rendered_task_instance_fields table
    for a given task_id and dag_id.

    :param rtif_num: number of RTIF rows to create, one per consecutive day
        starting at EXECUTION_DATE
    :param num_to_keep: value passed to ``RTIF.delete_old_records``
    :param remaining_rtifs: number of rows expected to remain afterwards
    :param expected_query_count: number of queries ``delete_old_records`` is
        expected to issue on non-MSSQL backends
    """
    session = settings.Session()
    dag = DAG("test_delete_old_records", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
    rtif_list = [
        RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num))) for num in range(rtif_num)
    ]

    session.add_all(rtif_list)
    session.commit()

    result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

    for rtif in rtif_list:
        assert rtif in result

    assert rtif_num == len(result)

    # Verify old records are deleted and only 'num_to_keep' records are kept.
    # On MSSQL the test expects one extra query from
    # RenderedTaskInstanceFields.delete_old_records, unless no queries are
    # expected at all. Every other backend uses expected_query_count unchanged.
    expected_query_count_based_on_db = (
        expected_query_count + 1
        if session.bind.dialect.name == "mssql" and expected_query_count != 0
        else expected_query_count
    )

    # Only the delete call is counted; the verification query runs afterwards.
    with assert_queries_count(expected_query_count_based_on_db):
        RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
    result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
    assert remaining_rtifs == len(result)
def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
    """
    Test that k8s_pod_yaml is rendered when an RTIF is built from a TI, that
    the RTIF can be stored in the DB, and that RTIF.get_k8s_pod_yaml returns
    the stored pod spec (or None for a TI that has no stored row).
    """
    dag = DAG("test_get_k8s_pod_yaml", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo hi")

    ti = TI(task=task, execution_date=EXECUTION_DATE)
    rtif = RTIF(ti=ti)

    # Building the RTIF must call the patched pod_mutation_hook exactly once,
    # with a single argument.
    mock_pod_mutation_hook.assert_called_once_with(mock.ANY)

    self.assertEqual(ti.dag_id, rtif.dag_id)
    self.assertEqual(ti.task_id, rtif.task_id)
    self.assertEqual(ti.execution_date, rtif.execution_date)

    annotations = {
        'dag_id': 'test_get_k8s_pod_yaml',
        'execution_date': '2019-01-01T00:00:00+00:00',
        'task_id': 'test',
        'try_number': '1',
    }
    labels = {
        'airflow-worker': 'worker-config',
        'airflow_version': version,
        'dag_id': 'test_get_k8s_pod_yaml',
        'execution_date': '2019-01-01T00_00_00_plus_00_00',
        'kubernetes_executor': 'True',
        'task_id': 'test',
        'try_number': '1',
    }
    container = {
        'command': [
            'airflow',
            'tasks',
            'run',
            'test_get_k8s_pod_yaml',
            'test',
            '2019-01-01T00:00:00+00:00',
        ],
        'image': ':',
        'name': 'base',
    }
    expected_pod_yaml = {
        'metadata': {
            'annotations': annotations,
            'labels': labels,
            # The pod name is not fixed, so accept any value.
            'name': mock.ANY,
            'namespace': 'default',
        },
        'spec': {'containers': [container]},
    }
    self.assertEqual(expected_pod_yaml, rtif.k8s_pod_yaml)

    # Persist the record before reading it back through the class method.
    with create_session() as session:
        session.add(rtif)

    self.assertEqual(expected_pod_yaml, RTIF.get_k8s_pod_yaml(ti=ti))

    # Else branch of get_k8s_pod_yaml: a TI with no stored row yields None.
    with dag:
        task_2 = BashOperator(task_id="test2", bash_command="echo hello")

    ti2 = TI(task_2, EXECUTION_DATE)
    self.assertIsNone(RTIF.get_k8s_pod_yaml(ti=ti2))