def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self):
    """A multi-query BigQueryExecuteQueryOperator survives a serialization round trip.

    Checks that ``sql`` is serialized and restored unchanged. Checks that two
    indexed ``BigQueryConsoleIndexableLink`` entries (index 0 and 1) are
    serialized. Checks that each deserialized link resolves to the job id pushed
    to XCom under ``job_id`` at the matching position.
    """
    with self.dag:
        BigQueryExecuteQueryOperator(
            task_id=TASK_ID,
            sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'],
        )

    serialized_dag = SerializedDAG.to_dict(self.dag)
    assert "sql" in serialized_dag["dag"]["tasks"][0]

    restored_dag = SerializedDAG.from_dict(serialized_dag)
    restored_task = restored_dag.task_dict[TASK_ID]
    assert restored_task.sql == ['SELECT * FROM test_table', 'SELECT * FROM test_table2']

    # --- Operator links: serialized form carries one indexed link per position ---
    link_class_path = 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink'
    expected_serialized_links = [{link_class_path: {'index': position}} for position in (0, 1)]
    assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == expected_serialized_links

    # --- Operator links: deserialized form resolves against XCom ---
    assert isinstance(list(restored_task.operator_extra_links)[0], BigQueryConsoleIndexableLink)

    ti = TaskInstance(task=restored_task, execution_date=DEFAULT_DATE)
    ti.xcom_push(key='job_id', value=['123', '45'])

    assert {'BigQuery Console #1', 'BigQuery Console #2'} == restored_task.operator_extra_link_dict.keys()

    expected_urls = {
        'BigQuery Console #1': 'https://console.cloud.google.com/bigquery?j=123',
        'BigQuery Console #2': 'https://console.cloud.google.com/bigquery?j=45',
    }
    for link_name, expected_url in expected_urls.items():
        assert expected_url == restored_task.get_extra_links(DEFAULT_DATE, link_name)
def test_serialization(self):
    """Serialization and deserialization should work for every DAG and Operator.

    Every collected DAG is serialized to a dict and validated against the
    serialization schema. The ``simple_dag`` result is then compared with the
    stored ground truth.
    """
    dags = collect_dags()
    serialized_dags = {}
    # Only the DAG objects are needed; the mapping keys were never used.
    for dag in dags.values():
        dag_dict = SerializedDAG.to_dict(dag)
        SerializedDAG.validate_schema(dag_dict)
        serialized_dags[dag.dag_id] = dag_dict

    # Compares with the ground truth of JSON string.
    self.validate_serialized_dag(serialized_dags['simple_dag'], serialized_simple_dag_ground_truth)
def test_extra_serialized_field_and_multiple_operator_links(self):
    """
    Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.

    This test also depends on GoogleLink() registered as a plugin
    in tests/plugins/test_plugin.py

    ``CustomOperator`` is given a list ``bash_command`` here. Its serialized form
    is expected to hold only the inbuilt ``CustomBaseIndexOpLink`` entries
    (index 0 and 1). After deserialization, the inbuilt links and the
    plugin-registered ``operator_extra_links`` must all be present and must
    resolve to the expected URLs.
    """
    test_date = datetime(2019, 8, 1)
    dag = DAG(dag_id='simple_dag', start_date=test_date)
    CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"])

    serialized_dag = SerializedDAG.to_dict(dag)
    self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0])

    restored_task = SerializedDAG.from_dict(serialized_dag).task_dict["simple_task"]
    self.assertEqual(restored_task.bash_command, ["echo", "true"])

    # --- Serialized form: only the inbuilt operator links are stored ---
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [
            {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}},
            {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}},
        ],
    )

    # --- Deserialized form: inbuilt and plugin links are all available ---
    self.assertCountEqual(
        restored_task.extra_links,
        ['BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google'],
    )

    ti = TaskInstance(task=restored_task, execution_date=test_date)
    ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"])

    inbuilt_expectations = (
        ("BigQuery Console #1", 'https://console.cloud.google.com/bigquery?j=dummy_value_1'),
        ("BigQuery Console #2", 'https://console.cloud.google.com/bigquery?j=dummy_value_2'),
    )
    for link_name, expected_url in inbuilt_expectations:
        self.assertEqual(expected_url, restored_task.get_extra_links(test_date, link_name))

    # Link registered via an Airflow plugin.
    self.assertEqual("https://www.google.com", restored_task.get_extra_links(test_date, GoogleLink.name))
def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date):
    """A task-level ``end_date`` is serialized only when it is not superseded by the DAG's.

    If the task has no end date, or the DAG's end date is not later than the
    task's, no ``end_date`` key is expected on the serialized task. After
    deserialization the task must report ``expected_task_end_date``.
    """
    dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1), end_date=dag_end_date)
    BaseOperator(task_id='simple_task', dag=dag, end_date=task_end_date)

    serialized_dag = SerializedDAG.to_dict(dag)
    serialized_task = serialized_dag["dag"]["tasks"][0]

    # If dag.end_date < task.end_date, dag.add_task() sets task.end_date to
    # dag.end_date, so no task-specific value is left to serialize.
    dag_end_date_wins = not task_end_date or dag_end_date <= task_end_date
    if dag_end_date_wins:
        assert "end_date" not in serialized_task
    else:
        assert "end_date" in serialized_task

    restored_dag = SerializedDAG.from_dict(serialized_dag)
    assert restored_dag.task_dict["simple_task"].end_date == expected_task_end_date
def test_deserialization_start_date(self, dag_start_date, task_start_date, expected_task_start_date):
    """A task-level ``start_date`` is serialized only when it is not superseded by the DAG's.

    If the task has no start date, or the DAG's start date is not earlier than
    the task's, no ``start_date`` key is expected on the serialized task. After
    deserialization the task must report ``expected_task_start_date``.
    """
    dag = DAG(dag_id='simple_dag', start_date=dag_start_date)
    BaseOperator(task_id='simple_task', dag=dag, start_date=task_start_date)

    serialized_dag = SerializedDAG.to_dict(dag)
    serialized_task = serialized_dag["dag"]["tasks"][0]

    # If dag.start_date > task.start_date, dag.add_task() sets
    # task.start_date to dag.start_date, so nothing task-specific is stored.
    dag_start_date_wins = not task_start_date or dag_start_date >= task_start_date
    if dag_start_date_wins:
        self.assertNotIn("start_date", serialized_task)
    else:
        self.assertIn("start_date", serialized_task)

    restored_dag = SerializedDAG.from_dict(serialized_dag)
    self.assertEqual(restored_dag.task_dict["simple_task"].start_date, expected_task_start_date)
def test_extra_serialized_field_and_operator_links(self):
    """
    Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.

    This test also depends on GoogleLink() registered as a plugin
    in tests/plugins/test_plugin.py

    The serialized ``CustomOperator`` is expected to store only its inbuilt
    ``CustomOpLink``. After deserialization, both that inbuilt link and the
    plugin-registered ``operator_extra_links`` must be available and must
    resolve correctly.
    """
    test_date = datetime(2019, 8, 1)
    dag = DAG(dag_id='simple_dag', start_date=test_date)
    CustomOperator(task_id='simple_task', dag=dag, bash_command="true")

    serialized_dag = SerializedDAG.to_dict(dag)
    assert "bash_command" in serialized_dag["dag"]["tasks"][0]

    restored_task = SerializedDAG.from_dict(serialized_dag).task_dict["simple_task"]
    assert restored_task.bash_command == "true"

    # --- Serialized form: only the inbuilt operator link is stored ---
    expected_serialized_links = [{'tests.test_utils.mock_operators.CustomOpLink': {}}]
    assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == expected_serialized_links

    # --- Deserialized form: inbuilt and plugin links are all available ---
    assert set(restored_task.extra_links) == {'Google Custom', 'airflow', 'github', 'google'}

    ti = TaskInstance(task=restored_task, execution_date=test_date)
    ti.xcom_push('search_query', "dummy_value_1")

    inbuilt_url = restored_task.get_extra_links(test_date, CustomOpLink.name)
    assert 'http://google.com/custom_base_link?search=dummy_value_1' == inbuilt_url

    plugin_url = restored_task.get_extra_links(test_date, GoogleLink.name)
    assert "https://www.google.com" == plugin_url
def test_task_params_roundtrip(self, val, expected_val):
    """
    Test that params work both on Serialized DAGs & Tasks

    Task-level ``params`` appear in the serialized task only when ``val`` is
    truthy. After deserialization the task's params must equal ``expected_val``.
    """
    dag = DAG(dag_id='simple_dag')
    BaseOperator(task_id='simple_task', dag=dag, params=val, start_date=datetime(2019, 8, 1))

    serialized_dag = SerializedDAG.to_dict(dag)
    serialized_task = serialized_dag["dag"]["tasks"][0]
    if not val:
        assert "params" not in serialized_task
    else:
        assert "params" in serialized_task

    restored_task = SerializedDAG.from_dict(serialized_dag).task_dict["simple_task"]
    assert expected_val == restored_task.params
def test_templated_fields_exist_in_serialized_dag(self, templated_field, expected_field):
    """
    Test that templated_fields exists for all Operators in Serialized DAG

    Since we don't want to inflate arbitrary python objects (it poses a RCE/security risk etc.)
    we want to check that non-"basic" objects are turned into strings after deserializing.
    """
    dag = DAG("test_serialized_template_fields", start_date=datetime(2019, 8, 1))
    with dag:
        BashOperator(task_id="test", bash_command=templated_field)

    restored_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
    restored_task = restored_dag.task_dict["test"]
    assert expected_field == restored_task.bash_command
def test_task_group_serialization(self):
    """
    Test TaskGroup serialization/deserialization.

    Builds a DAG with nested task groups and checks that it passes schema
    validation and survives a JSON round trip. Also checks that, after a
    serialize/deserialize cycle, every operator in the restored task-group tree
    serializes identically to a round-tripped copy of the original operator.
    """
    from airflow.operators.dummy import DummyOperator
    from airflow.utils.task_group import TaskGroup

    execution_date = datetime(2020, 1, 1)
    with DAG("test_task_group_serialization", start_date=execution_date) as dag:
        task1 = DummyOperator(task_id="task1")
        with TaskGroup("group234") as group234:
            _ = DummyOperator(task_id="task2")

            with TaskGroup("group34") as group34:
                _ = DummyOperator(task_id="task3")
                _ = DummyOperator(task_id="task4")

        task5 = DummyOperator(task_id="task5")
        task1 >> group234
        group34 >> task5

    SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))

    from_json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
    self.validate_deserialized_dag(from_json_dag, dag)

    round_tripped = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))

    assert round_tripped.task_group.children
    assert round_tripped.task_group.children.keys() == dag.task_group.children.keys()

    def _assert_subtree_matches(node):
        # Group nodes expose ``children``; anything without it is treated as an operator leaf.
        try:
            members = node.children.values()
        except AttributeError:
            original_op = dag.get_task(node.task_id)
            reference_op = SerializedBaseOperator.deserialize_operator(
                SerializedBaseOperator.serialize_operator(original_op)
            )
            reference_dict = SerializedBaseOperator.serialize_operator(reference_op)

            assert node
            assert SerializedBaseOperator.serialize_operator(node) == reference_dict
        else:
            for member in members:
                _assert_subtree_matches(member)

    _assert_subtree_matches(round_tripped.task_group)
def test_dag_params_roundtrip(self, val, expected_val):
    """
    Test that params work both on Serialized DAGs & Tasks

    DAG-level ``params`` appear in the serialized DAG only when ``val`` is
    truthy. After deserialization both the DAG and its task must expose
    ``expected_val`` as params.
    """
    dag = DAG(dag_id='simple_dag', params=val)
    BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

    serialized_dag = SerializedDAG.to_dict(dag)
    dag_section = serialized_dag["dag"]
    if not val:
        self.assertNotIn("params", dag_section)
    else:
        self.assertIn("params", dag_section)

    restored_dag = SerializedDAG.from_dict(serialized_dag)
    restored_task = restored_dag.task_dict["simple_task"]
    self.assertEqual(expected_val, restored_dag.params)
    self.assertEqual(expected_val, restored_task.params)
def test_console_extra_link_serialized_field(self):
    """The AI Platform console link survives serialization and resolves from XCom.

    The serialized task is expected to carry a single ``AIPlatformConsoleLink``.
    After deserialization the link builds a console URL from the
    ``gcp_metadata`` XCom pushed for ``DEFAULT_DATE``. For 2019-01-01, where
    nothing is pushed in this test, it yields ``''``.
    """
    training_args = self.TRAINING_DEFAULT_ARGS
    with self.dag:
        training_op = MLEngineStartTrainingJobOperator(**training_args)

    serialized_dag = SerializedDAG.to_dict(self.dag)
    restored_task = SerializedDAG.from_dict(serialized_dag).task_dict[training_args['task_id']]

    # --- Serialized form of the operator link ---
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}],
    )

    # --- Deserialized form of the operator link ---
    self.assertIsInstance(list(restored_task.operator_extra_links)[0], AIPlatformConsoleLink)

    job_id = training_args['job_id']
    project_id = training_args['project_id']

    ti = TaskInstance(task=training_op, execution_date=DEFAULT_DATE)
    ti.xcom_push(key='gcp_metadata', value={"job_id": job_id, "project_id": project_id})

    self.assertEqual(
        f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
        restored_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
    )
    self.assertEqual(
        '',
        restored_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
    )
def test_bigquery_operator_extra_serialized_field_when_single_query(self):
    """A single-query BigQueryExecuteQueryOperator survives a serialization round trip.

    Checks that ``sql`` is restored unchanged and that exactly one
    ``BigQueryConsoleLink`` is serialized. Checks that the deserialized link
    builds a console URL from the ``job_id`` XCom for ``DEFAULT_DATE``, and
    returns ``''`` for a date where nothing was pushed.
    """
    with self.dag:
        BigQueryExecuteQueryOperator(
            task_id=TASK_ID,
            sql='SELECT * FROM test_table',
        )

    serialized_dag = SerializedDAG.to_dict(self.dag)
    self.assertIn("sql", serialized_dag["dag"]["tasks"][0])

    restored_task = SerializedDAG.from_dict(serialized_dag).task_dict[TASK_ID]
    self.assertEqual(restored_task.sql, 'SELECT * FROM test_table')

    # --- Serialized form of the operator link ---
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}],
    )

    # --- Deserialized form of the operator link ---
    self.assertIsInstance(list(restored_task.operator_extra_links)[0], BigQueryConsoleLink)

    ti = TaskInstance(task=restored_task, execution_date=DEFAULT_DATE)
    ti.xcom_push('job_id', 12345)

    # Positive case: the execution date that received the XCom.
    matching_url = restored_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
    self.assertEqual(matching_url, 'https://console.cloud.google.com/bigquery?j=12345')

    # Negative case: an execution date with nothing pushed.
    other_date_url = restored_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name)
    self.assertEqual(other_date_url, '')
def test_dag_on_success_callback_roundtrip(self, passed_success_callback, expected_value):
    """
    Test that when on_success_callback is passed to the DAG, has_on_success_callback is stored
    in Serialized JSON blob. And when it is de-serialized dag.has_on_success_callback is set to True.

    When the callback is not set, has_on_success_callback should not be stored in Serialized blob
    and so default to False on de-serialization
    """
    dag = DAG(dag_id='test_dag_on_success_callback_roundtrip', **passed_success_callback)
    BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

    serialized_dag = SerializedDAG.to_dict(dag)
    dag_section = serialized_dag["dag"]
    if not expected_value:
        assert "has_on_success_callback" not in dag_section
    else:
        assert "has_on_success_callback" in dag_section

    restored_dag = SerializedDAG.from_dict(serialized_dag)
    assert restored_dag.has_on_success_callback is expected_value
def test_edge_info_serialization(self):
    """
    Tests edge_info serialization/deserialization.

    A labelled edge between two tasks must pass schema validation, survive a
    JSON round trip, and produce identical ``edge_info`` after a
    serialize/deserialize cycle.
    """
    from airflow.operators.dummy import DummyOperator
    from airflow.utils.edgemodifier import Label

    with DAG("test_edge_info_serialization", start_date=datetime(2020, 1, 1)) as dag:
        upstream = DummyOperator(task_id="task1")
        downstream = DummyOperator(task_id="task2")
        upstream >> Label("test label") >> downstream  # pylint: disable=W0106

    SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))

    from_json_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
    self.validate_deserialized_dag(from_json_dag, dag)

    round_tripped = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
    assert round_tripped.edge_info == dag.edge_info
def __init__(self, dag: DAG):
    """Capture ``dag``'s identity, file location, serialized form and content hash.

    The serialized dict is JSON-encoded with sorted keys and hashed with MD5
    into ``dag_hash``. Storage depends on ``COMPRESS_SERIALIZED_DAGS``. When it
    is truthy, only the zlib-compressed JSON bytes go into ``_data_compressed``
    and ``_data`` is ``None``. Otherwise, the plain dict goes into ``_data`` and
    ``_data_compressed`` is ``None``.
    """
    self.dag_id = dag.dag_id
    self.fileloc = dag.fileloc
    self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
    self.last_updated = timezone.utcnow()

    serialized = SerializedDAG.to_dict(dag)
    # sort_keys=True makes the encoded bytes (and therefore the hash) independent
    # of dict insertion order.
    encoded = json.dumps(serialized, sort_keys=True).encode("utf-8")
    # NOTE(review): MD5 here appears to be a content fingerprint, not a security
    # control — confirm before changing the algorithm.
    self.dag_hash = hashlib.md5(encoded).hexdigest()

    if not COMPRESS_SERIALIZED_DAGS:
        self._data, self._data_compressed = serialized, None
    else:
        self._data, self._data_compressed = None, zlib.compress(encoded)

    # Keep the decoded dict around so that, when compression is enabled,
    # reading the data does not require decompressing and re-parsing it.
    self.__data_cache = serialized
def test_dagrun_update_state_with_handle_callback_failure(self):
    """A failed DagRun hands back a failure DagCallbackRequest when callbacks are not executed.

    The DAG has one succeeded task and one failed task. After
    ``update_state(execute_callbacks=False)``, the run must be FAILED and the
    returned callback must be a failure ``DagCallbackRequest`` for the run.
    """
    dag_id = 'test_dagrun_update_state_with_handle_callback_failure'

    def on_failure_callable(context):
        self.assertEqual(context['dag_run'].dag_id, dag_id)

    dag = DAG(
        dag_id=dag_id,
        start_date=datetime.datetime(2017, 1, 1),
        on_failure_callback=on_failure_callable,
    )
    succeeded_task = DummyOperator(task_id='test_state_succeeded1', dag=dag)
    failed_task = DummyOperator(task_id='test_state_failed2', dag=dag)
    succeeded_task.set_downstream(failed_task)

    initial_task_states = {
        'test_state_succeeded1': State.SUCCESS,
        'test_state_failed2': State.FAILED,
    }

    # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
    dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

    dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
    _, callback = dag_run.update_state(execute_callbacks=False)
    self.assertEqual(State.FAILED, dag_run.state)

    # execute_callbacks=False is passed above, so update_state is expected to
    # return a request describing the failure callback.
    expected_request = DagCallbackRequest(
        full_filepath=dag_run.dag.fileloc,
        dag_id=dag_id,
        execution_date=dag_run.execution_date,
        is_failure_callback=True,
        msg="task_failure",
    )
    assert callback == expected_request
def test_deserialization_with_dag_context(self):
    """Serializing a DAG from inside its own ``with DAG(...)`` block must succeed."""
    dag_start = datetime(2019, 8, 1, tzinfo=timezone.utc)
    with DAG(dag_id='simple_dag', start_date=dag_start) as dag:
        BaseOperator(task_id='simple_task')
        # Must not raise "RuntimeError: dictionary changed size during iteration".
        SerializedDAG.to_dict(dag)
def __init__(self, dag):
    """Capture ``dag``'s id, file location (plus its hash), serialized dict and update time."""
    source_path = dag.full_filepath
    self.dag_id = dag.dag_id
    self.fileloc = source_path
    self.fileloc_hash = DagCode.dag_fileloc_hash(source_path)
    self.data = SerializedDAG.to_dict(dag)
    self.last_updated = timezone.utcnow()