def test_get_templated_fields(self, templated_field, expected_rendered_field):
    """
    Test that template_fields are rendered correctly, stored in the Database,
    and are correctly fetched using RTIF.get_templated_fields
    """
    dag = DAG("test_serialized_rendered_fields", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command=templated_field)

    ti = TI(task=task, execution_date=EXECUTION_DATE)
    rtif = RTIF(ti=ti)
    self.assertEqual(ti.dag_id, rtif.dag_id)
    self.assertEqual(ti.task_id, rtif.task_id)
    self.assertEqual(ti.execution_date, rtif.execution_date)
    self.assertEqual(expected_rendered_field, rtif.rendered_fields.get("bash_command"))

    with create_session() as session:
        session.add(rtif)

    self.assertEqual(
        {"bash_command": expected_rendered_field, "env": None},
        RTIF.get_templated_fields(ti=ti),
    )

    # Test the else part of get_templated_fields
    # i.e. for the TIs that are not stored in RTIF table
    # Fetching them will return None
    with dag:
        task_2 = BashOperator(task_id="test2", bash_command=templated_field)

    ti2 = TI(task_2, EXECUTION_DATE)
    self.assertIsNone(RTIF.get_templated_fields(ti=ti2))
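# --- Illustrative only: test_get_templated_fields above is clearly parametrized with
# --- (templated_field, expected_rendered_field) pairs, but the decorator is not part of
# --- this excerpt. The cases below are assumptions showing the intended shape, relying on
# --- the fact that the task is created with task_id="test", so "{{ task.task_id }}"
# --- renders to "test".
import pytest

ILLUSTRATIVE_TEMPLATED_FIELD_CASES = [
    ("{{ task.task_id }}", "test"),                            # plain Jinja string
    (["{{ task.task_id }}", "literal"], ["test", "literal"]),  # templates inside a list
    ({"key": "{{ task.task_id }}"}, {"key": "test"}),          # templates inside dict values
    ("no template here", "no template here"),                  # untemplated values pass through
]
# e.g. @pytest.mark.parametrize("templated_field, expected_rendered_field",
#                               ILLUSTRATIVE_TEMPLATED_FIELD_CASES)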
def test_default_pool_open_slots(self):
    set_default_pool_slots(5)
    assert 5 == Pool.get_default_pool().open_slots()

    dag = DAG(
        dag_id='test_default_pool_open_slots',
        start_date=DEFAULT_DATE,
    )
    op1 = DummyOperator(task_id='dummy1', dag=dag)
    op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
    ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
    ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
    ti1.state = State.RUNNING
    ti2.state = State.QUEUED

    session = settings.Session
    session.add(ti1)
    session.add(ti2)
    session.commit()
    session.close()

    assert 2 == Pool.get_default_pool().open_slots()
    assert {
        "default_pool": {
            "open": 2,
            "queued": 2,
            "total": 5,
            "running": 1,
        }
    } == Pool.slots_stats()
def test_write(self):
    """
    Test records can be written and overwritten
    """
    Variable.set(key="test_key", value="test_val")

    session = settings.Session()
    result = session.query(RTIF).all()
    assert [] == result

    with DAG("test_write", start_date=START_DATE):
        task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

    rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
    rtif.write()
    result = (
        session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
        .filter(
            RTIF.dag_id == rtif.dag_id,
            RTIF.task_id == rtif.task_id,
            RTIF.execution_date == rtif.execution_date,
        )
        .first()
    )
    assert ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}) == result

    # Test that overwrite saves new values to the DB
    Variable.delete("test_key")
    Variable.set(key="test_key", value="test_val_updated")
    with DAG("test_write", start_date=START_DATE):
        updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

    rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE))
    rtif_updated.write()

    result_updated = (
        session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
        .filter(
            RTIF.dag_id == rtif_updated.dag_id,
            RTIF.task_id == rtif_updated.task_id,
            RTIF.execution_date == rtif_updated.execution_date,
        )
        .first()
    )
    assert (
        'test_write',
        'test',
        {'bash_command': 'echo test_val_updated', 'env': None},
    ) == result_updated
def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
    for map_index in indexes:
        ti = TI(task, run_id=self.run_id, map_index=map_index)
        task_instance_mutation_hook(ti)
        created_counts[ti.operator] += 1
        yield ti
def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
    """
    Test that old records are deleted from rendered_task_instance_fields table
    for a given task_id and dag_id.
    """
    session = settings.Session()
    dag = DAG("test_delete_old_records", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

    rtif_list = [
        RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num)))
        for num in range(rtif_num)
    ]
    session.add_all(rtif_list)
    session.commit()

    result = session.query(RTIF).filter(
        RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id
    ).all()

    for rtif in rtif_list:
        self.assertIn(rtif, result)

    self.assertEqual(rtif_num, len(result))

    # Verify old records are deleted and only 'num_to_keep' records are kept
    with assert_queries_count(expected_query_count):
        RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)

    result = session.query(RTIF).filter(
        RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id
    ).all()
    self.assertEqual(remaining_rtifs, len(result))
def verify_integrity(self, session: Session = None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state is not State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. "
                    "Marking it as removed.", ti, dag
                )
                Stats.incr("task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info(
                "Restoring task '%s' which was previously "
                "removed from DAG '%s'", ti, dag
            )
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr("task_instance_created-{}".format(task.task_type), 1, 1)
            ti = TI(task, self.execution_date)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.commit()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info('Hit IntegrityError while creating the TIs for '
                      f'{dag.dag_id} - {self.execution_date}.')
        self.log.info('Doing session rollback.')
        session.rollback()
def verify_integrity(self, session: Session = NEW_SESSION):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    from airflow.settings import task_instance_mutation_hook

    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                ti.state = State.REMOVED

        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
            Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
            ti = TI(task, execution_date=None, run_id=self.run_id)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.flush()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info('Hit IntegrityError while creating the TIs for %s - %s', dag.dag_id, self.run_id)
        self.log.info('Doing session rollback.')
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()
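# --- Illustrative only: a minimal sketch of how the run_id-based verify_integrity above
# --- might be exercised. The DAG, operators and assertions here are assumptions, not part
# --- of the excerpts, and a real run additionally needs an initialised metadata database.
dag = DAG("verify_integrity_example", start_date=DEFAULT_DATE)
DummyOperator(task_id="existing_task", dag=dag)
dag_run = dag.create_dagrun(
    run_id="manual__verify_integrity_example",
    state=State.RUNNING,
    execution_date=DEFAULT_DATE,
)

# A task added after the run was created has no TaskInstance yet; verify_integrity
# should detect it as missing and insert a TI for this run_id.
DummyOperator(task_id="late_task", dag=dag)
dag_run.dag = dag
dag_run.verify_integrity()
assert {ti.task_id for ti in dag_run.get_task_instances()} == {"existing_task", "late_task"}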
def test_open_slots(self):
    pool = Pool(pool='test_pool', slots=5)
    dag = DAG(
        dag_id='test_open_slots',
        start_date=DEFAULT_DATE,
    )
    t1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
    t2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
    ti1 = TI(task=t1, execution_date=DEFAULT_DATE)
    ti2 = TI(task=t2, execution_date=DEFAULT_DATE)
    ti1.state = State.RUNNING
    ti2.state = State.QUEUED

    session = settings.Session
    session.add(pool)
    session.add(ti1)
    session.add(ti2)
    session.commit()
    session.close()

    self.assertEqual(3, pool.open_slots())
def test_open_slots(self):
    pool = Pool(pool='test_pool', slots=5)
    dag = DAG(
        dag_id='test_open_slots',
        start_date=DEFAULT_DATE,
    )
    op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
    op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
    ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
    ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
    ti1.state = State.RUNNING
    ti2.state = State.QUEUED

    session = settings.Session
    session.add(pool)
    session.add(ti1)
    session.add(ti2)
    session.commit()
    session.close()

    self.assertEqual(3, pool.open_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(1, pool.running_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(
        {
            "default_pool": {
                "open": 128,
                "queued": 0,
                "total": 128,
                "running": 0,
            },
            "test_pool": {
                "open": 3,
                "queued": 1,
                "running": 1,
                "total": 5,
            },
        },
        pool.slots_stats(),
    )
def test_default_pool_open_slots(self):
    set_default_pool_slots(5)
    self.assertEqual(5, Pool.get_default_pool().open_slots())

    dag = DAG(
        dag_id='test_default_pool_open_slots',
        start_date=DEFAULT_DATE,
    )
    op1 = DummyOperator(task_id='dummy1', dag=dag)
    op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
    ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
    ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
    ti1.state = State.RUNNING
    ti2.state = State.QUEUED

    session = settings.Session
    session.add(ti1)
    session.add(ti2)
    session.commit()
    session.close()

    self.assertEqual(2, Pool.get_default_pool().open_slots())
def test_infinite_slots(self):
    pool = Pool(pool='test_pool', slots=-1)
    dag = DAG(
        dag_id='test_infinite_slots',
        start_date=DEFAULT_DATE,
    )
    op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
    op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
    ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
    ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
    ti1.state = State.RUNNING
    ti2.state = State.QUEUED

    session = settings.Session
    session.add(pool)
    session.add(ti1)
    session.add(ti2)
    session.commit()
    session.close()

    self.assertEqual(float('inf'), pool.open_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(1, pool.used_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
    self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter
def test_delete_old_records(self):
    """
    Test that old records are deleted from rendered_task_instance_fields table
    for a given task_id and dag_id.
    """
    session = settings.Session()
    dag = DAG("test_delete_old_records", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

    rtif_1 = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
    rtif_2 = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=1)))
    rtif_3 = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=2)))
    session.add(rtif_1)
    session.add(rtif_2)
    session.add(rtif_3)
    session.commit()

    result = session.query(RTIF) \
        .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

    self.assertIn(rtif_1, result)
    self.assertIn(rtif_2, result)
    self.assertIn(rtif_3, result)
    self.assertEqual(3, len(result))

    # Verify old records are deleted and only 1 record is kept
    RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=1)
    result = session.query(RTIF) \
        .filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
    self.assertEqual(1, len(result))
    self.assertEqual(rtif_3.execution_date, result[0].execution_date)
def verify_integrity(self, session=None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.
    """
    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = []
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.append(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state is not State.RUNNING and not dag.partial:
                self.log.warning("Failed to get task '{}' for dag '{}'. "
                                 "Marking it as removed.".format(ti, dag))
                Stats.incr("task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info("Restoring task '{}' which was previously "
                          "removed from DAG '{}'".format(ti, dag))
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(
                "task_instance_created-{}".format(task.__class__.__name__), 1, 1)
            ti = TI(task, self.execution_date)
            task_instance_mutation_hook(ti)
            session.add(ti)

    session.commit()
def test_skip_all_except(self):
    dag = DAG(
        'dag_test_skip_all_except',
        start_date=DEFAULT_DATE,
    )
    with dag:
        task1 = DummyOperator(task_id='task1')
        task2 = DummyOperator(task_id='task2')
        task3 = DummyOperator(task_id='task3')

        task1 >> [task2, task3]

    ti1 = TI(task1, execution_date=DEFAULT_DATE)
    ti2 = TI(task2, execution_date=DEFAULT_DATE)
    ti3 = TI(task3, execution_date=DEFAULT_DATE)

    SkipMixin().skip_all_except(ti=ti1, branch_task_ids=['task2'])

    def get_state(ti):
        ti.refresh_from_db()
        return ti.state

    assert get_state(ti2) == State.NONE
    assert get_state(ti3) == State.SKIPPED
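# --- Context, not part of the excerpts above: SkipMixin.skip_all_except (exercised by
# --- test_skip_all_except) is the same mechanism branching operators rely on. A minimal,
# --- illustrative branch DAG follows; the import path is the Airflow 2.x location and is
# --- an assumption relative to these excerpts.
from airflow.operators.python import BranchPythonOperator

with DAG('dag_branch_example', start_date=DEFAULT_DATE) as branch_dag:
    branching = BranchPythonOperator(task_id='branching', python_callable=lambda: 'go_left')
    go_left = DummyOperator(task_id='go_left')
    go_right = DummyOperator(task_id='go_right')
    # When the callable returns 'go_left', skip_all_except leaves go_left schedulable
    # and marks go_right as skipped, matching the assertions in the test above.
    branching >> [go_left, go_right]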
def test_redact(self, redact):
    dag = DAG("test_ritf_redact", start_date=START_DATE)
    with dag:
        task = BashOperator(
            task_id="test",
            bash_command="echo {{ var.value.api_key }}",
            env={'foo': 'secret', 'other_api_key': 'masked based on key name'},
        )
    redact.side_effect = [
        'val 1',
        'val 2',
    ]

    ti = TI(task=task, execution_date=EXECUTION_DATE)
    rtif = RTIF(ti=ti)
    assert rtif.rendered_fields == {
        'bash_command': 'val 1',
        'env': 'val 2',
    }
def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
    """
    Test that old records are deleted from rendered_task_instance_fields table
    for a given task_id and dag_id.
    """
    session = settings.Session()
    dag = DAG("test_delete_old_records", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

    rtif_list = [
        RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num)))
        for num in range(rtif_num)
    ]
    session.add_all(rtif_list)
    session.commit()

    result = session.query(RTIF).filter(
        RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id
    ).all()

    for rtif in rtif_list:
        assert rtif in result

    assert rtif_num == len(result)

    # Verify old records are deleted and only 'num_to_keep' records are kept
    # For other DBs, an extra query is fired in RenderedTaskInstanceFields.delete_old_records
    expected_query_count_based_on_db = (
        expected_query_count + 1
        if session.bind.dialect.name == "mssql" and expected_query_count != 0
        else expected_query_count
    )
    with assert_queries_count(expected_query_count_based_on_db):
        RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)

    result = session.query(RTIF).filter(
        RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id
    ).all()
    assert remaining_rtifs == len(result)
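# --- Illustrative only: both test_delete_old_records variants above run under a parametrize
# --- decorator supplying (rtif_num, num_to_keep, remaining_rtifs, expected_query_count).
# --- The real parameter sets are not part of this excerpt; the tuples below are assumptions
# --- that only demonstrate the invariant the parameters must satisfy, and the query counts
# --- in particular depend on the delete_old_records implementation and database dialect.
ILLUSTRATIVE_DELETE_OLD_RECORDS_CASES = [
    # (rtif_num, num_to_keep, remaining_rtifs, expected_query_count)
    (3, 1, 1, 1),  # three records written, keep the newest one
    (5, 2, 2, 1),  # five records written, keep the newest two
    (1, 2, 1, 1),  # fewer records than num_to_keep: nothing gets deleted
]
for rtif_num, num_to_keep, remaining_rtifs, _ in ILLUSTRATIVE_DELETE_OLD_RECORDS_CASES:
    # After deletion, the surviving row count is bounded by both values.
    assert remaining_rtifs == min(rtif_num, num_to_keep)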
def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook):
    """
    Test that k8s_pod_yaml is rendered correctly, stored in the Database,
    and is correctly fetched using RTIF.get_k8s_pod_yaml
    """
    dag = DAG("test_get_k8s_pod_yaml", start_date=START_DATE)
    with dag:
        task = BashOperator(task_id="test", bash_command="echo hi")
    ti = TI(task=task, execution_date=EXECUTION_DATE)
    rtif = RTIF(ti=ti)

    # Test that pod_mutation_hook is called
    mock_pod_mutation_hook.assert_called_once_with(mock.ANY)

    self.assertEqual(ti.dag_id, rtif.dag_id)
    self.assertEqual(ti.task_id, rtif.task_id)
    self.assertEqual(ti.execution_date, rtif.execution_date)

    expected_pod_yaml = {
        'metadata': {
            'annotations': {
                'dag_id': 'test_get_k8s_pod_yaml',
                'execution_date': '2019-01-01T00:00:00+00:00',
                'task_id': 'test',
                'try_number': '1',
            },
            'labels': {
                'airflow-worker': 'worker-config',
                'airflow_version': version,
                'dag_id': 'test_get_k8s_pod_yaml',
                'execution_date': '2019-01-01T00_00_00_plus_00_00',
                'kubernetes_executor': 'True',
                'task_id': 'test',
                'try_number': '1',
            },
            'name': mock.ANY,
            'namespace': 'default',
        },
        'spec': {
            'containers': [
                {
                    'command': [
                        'airflow',
                        'tasks',
                        'run',
                        'test_get_k8s_pod_yaml',
                        'test',
                        '2019-01-01T00:00:00+00:00',
                    ],
                    'image': ':',
                    'name': 'base',
                }
            ]
        },
    }
    self.assertEqual(expected_pod_yaml, rtif.k8s_pod_yaml)

    with create_session() as session:
        session.add(rtif)

    self.assertEqual(expected_pod_yaml, RTIF.get_k8s_pod_yaml(ti=ti))

    # Test the else part of get_k8s_pod_yaml
    # i.e. for the TIs that are not stored in RTIF table
    # Fetching them will return None
    with dag:
        task_2 = BashOperator(task_id="test2", bash_command="echo hello")

    ti2 = TI(task_2, EXECUTION_DATE)
    self.assertIsNone(RTIF.get_k8s_pod_yaml(ti=ti2))
def create_ti(task: "BaseOperator") -> TI:
    ti = TI(task, run_id=self.run_id)
    task_instance_mutation_hook(ti)
    created_counts[ti.operator] += 1
    return ti