示例#1
0
    def test_bigquery_operator_extra_serialized_field_when_multiple_queries(
            self):
        """Round-trip a multi-statement BigQuery operator and check its indexed console links.

        With two SQL statements the serialized task carries two
        ``BigQueryConsoleIndexableLink`` entries (index 0 and 1); after
        deserialization, "BigQuery Console #N" must resolve against the N-th
        job id pushed to XCom.
        """
        queries = ['SELECT * FROM test_table', 'SELECT * FROM test_table2']
        with self.dag:
            BigQueryExecuteQueryOperator(task_id=TASK_ID, sql=queries)

        serialized_dag = SerializedDAG.to_dict(self.dag)
        serialized_task = serialized_dag["dag"]["tasks"][0]
        assert "sql" in serialized_task

        restored_task = SerializedDAG.from_dict(serialized_dag).task_dict[TASK_ID]
        assert restored_task.sql == queries

        # --- Operator links on the serialized side: one indexable link per statement.
        link_path = 'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink'
        expected_serialized_links = [{link_path: {'index': position}} for position in range(2)]
        assert serialized_task["_operator_extra_links"] == expected_serialized_links

        # --- Operator links on the deserialized side.
        first_link = list(restored_task.operator_extra_links)[0]
        assert isinstance(first_link, BigQueryConsoleIndexableLink)

        task_instance = TaskInstance(task=restored_task, execution_date=DEFAULT_DATE)
        task_instance.xcom_push(key='job_id', value=['123', '45'])

        assert {'BigQuery Console #1', 'BigQuery Console #2'
                } == restored_task.operator_extra_link_dict.keys()

        expectations = (
            ('BigQuery Console #1', 'https://console.cloud.google.com/bigquery?j=123'),
            ('BigQuery Console #2', 'https://console.cloud.google.com/bigquery?j=45'),
        )
        for link_name, expected_url in expectations:
            assert expected_url == restored_task.get_extra_links(DEFAULT_DATE, link_name)
    def test_serialization(self):
        """Every collected DAG must serialize to a schema-valid dict.

        The serialized ``simple_dag`` is additionally compared against the
        ground-truth structure.
        """
        dags = collect_dags()
        serialized_dags = {}
        # Only the DAG objects are needed; the mapping's keys are unused.
        for dag in dags.values():
            serialized = SerializedDAG.to_dict(dag)
            SerializedDAG.validate_schema(serialized)
            serialized_dags[dag.dag_id] = serialized

        # Compares with the ground truth of JSON string.
        self.validate_serialized_dag(serialized_dags['simple_dag'],
                                     serialized_simple_dag_ground_truth)
    def test_extra_serialized_field_and_multiple_operator_links(self):
        """
        Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.

        This test also depends on GoogleLink() registered as a plugin
        in tests/plugins/test_plugin.py

        The function tests that if extra operator links are registered in plugin
        in ``operator_extra_links`` and the same is also defined in
        the Operator in ``BaseOperator.operator_extra_links``, it has the correct
        extra link.

        Here ``bash_command`` has two items, and the serialized task is expected
        to carry two ``CustomBaseIndexOpLink`` entries (index 0 and 1) that
        resolve against the matching items of the ``search_query`` XCom value.
        """
        test_date = datetime(2019, 8, 1)
        dag = DAG(dag_id='simple_dag', start_date=test_date)
        CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"])

        serialized_dag = SerializedDAG.to_dict(dag)
        self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0])

        dag = SerializedDAG.from_dict(serialized_dag)
        simple_task = dag.task_dict["simple_task"]
        self.assertEqual(simple_task.bash_command, ["echo", "true"])

        #########################################################
        # Verify Operator Links work with Serialized Operator
        #########################################################
        # Check Serialized version of operator link only contains the inbuilt Op Link
        self.assertEqual(
            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
            [
                {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}},
                {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}},
            ]
        )

        # Test all the extra_links are set
        self.assertCountEqual(simple_task.extra_links, [
            'BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google'])

        ti = TaskInstance(task=simple_task, execution_date=test_date)
        ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"])

        # Test Deserialized inbuilt link #1
        custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #1")
        self.assertEqual('https://console.cloud.google.com/bigquery?j=dummy_value_1', custom_inbuilt_link)

        # Test Deserialized inbuilt link #2
        custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #2")
        self.assertEqual('https://console.cloud.google.com/bigquery?j=dummy_value_2', custom_inbuilt_link)

        # Test Deserialized link registered via Airflow Plugin
        google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name)
        self.assertEqual("https://www.google.com", google_link_from_plugin)
示例#4
0
    def test_deserialization_end_date(self, dag_end_date, task_end_date, expected_task_end_date):
        """
        Check whether a task ``end_date`` is serialized and what it deserializes to.

        The task-level ``end_date`` should appear in the serialized task only
        when it is set and strictly earlier than the DAG's ``end_date``.
        """
        dag = DAG(dag_id='simple_dag', start_date=datetime(2019, 8, 1), end_date=dag_end_date)
        BaseOperator(task_id='simple_task', dag=dag, end_date=task_end_date)

        serialized_dag = SerializedDAG.to_dict(dag)
        task_blob = serialized_dag["dag"]["tasks"][0]
        if task_end_date and dag_end_date > task_end_date:
            # The task finishes before the DAG does, so its own end_date is kept.
            assert "end_date" in task_blob
        else:
            # Either no task end_date was given, or dag.add_task() clamps a
            # later-or-equal one to dag.end_date; the field is expected to be omitted.
            assert "end_date" not in task_blob

        restored_task = SerializedDAG.from_dict(serialized_dag).task_dict["simple_task"]
        assert restored_task.end_date == expected_task_end_date
    def test_deserialization_start_date(self, dag_start_date, task_start_date, expected_task_start_date):
        """
        Check whether a task ``start_date`` is serialized and what it deserializes to.

        The task-level ``start_date`` should appear in the serialized task only
        when it is set and strictly later than the DAG's ``start_date``.
        """
        dag = DAG(dag_id='simple_dag', start_date=dag_start_date)
        BaseOperator(task_id='simple_task', dag=dag, start_date=task_start_date)

        serialized_dag = SerializedDAG.to_dict(dag)
        task_blob = serialized_dag["dag"]["tasks"][0]
        if task_start_date and dag_start_date < task_start_date:
            # The task starts after the DAG does, so its own start_date is kept.
            self.assertIn("start_date", task_blob)
        else:
            # Either no task start_date was given, or dag.add_task() raises an
            # earlier-or-equal one to dag.start_date; the field is expected to be omitted.
            self.assertNotIn("start_date", task_blob)

        restored_task = SerializedDAG.from_dict(serialized_dag).task_dict["simple_task"]
        self.assertEqual(restored_task.start_date, expected_task_start_date)
    def test_extra_serialized_field_and_operator_links(self):
        """
        Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.

        This test also depends on GoogleLink() registered as a plugin
        in tests/plugins/test_plugin.py

        The function tests that if extra operator links are registered in plugin
        in ``operator_extra_links`` and the same is also defined in
        the Operator in ``BaseOperator.operator_extra_links``, it has the correct
        extra link.

        With a plain-string ``bash_command``, the serialized task is expected to
        carry a single ``CustomOpLink`` entry.
        """
        test_date = datetime(2019, 8, 1)
        dag = DAG(dag_id='simple_dag', start_date=test_date)
        CustomOperator(task_id='simple_task', dag=dag, bash_command="true")

        serialized_dag = SerializedDAG.to_dict(dag)
        assert "bash_command" in serialized_dag["dag"]["tasks"][0]

        dag = SerializedDAG.from_dict(serialized_dag)
        simple_task = dag.task_dict["simple_task"]
        assert simple_task.bash_command == "true"

        #########################################################
        # Verify Operator Links work with Serialized Operator
        #########################################################
        # Check Serialized version of operator link only contains the inbuilt Op Link
        assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [{
            'tests.test_utils.mock_operators.CustomOpLink': {}
        }]

        # Test all the extra_links are set
        assert set(simple_task.extra_links) == {
            'Google Custom', 'airflow', 'github', 'google'
        }

        ti = TaskInstance(task=simple_task, execution_date=test_date)
        ti.xcom_push('search_query', "dummy_value_1")

        # Test Deserialized inbuilt link
        custom_inbuilt_link = simple_task.get_extra_links(
            test_date, CustomOpLink.name)
        assert 'http://google.com/custom_base_link?search=dummy_value_1' == custom_inbuilt_link

        # Test Deserialized link registered via Airflow Plugin
        google_link_from_plugin = simple_task.get_extra_links(
            test_date, GoogleLink.name)
        assert "https://www.google.com" == google_link_from_plugin
示例#7
0
    def test_task_params_roundtrip(self, val, expected_val):
        """
        Check that ``params`` given to a task survive a serialize/deserialize round trip.

        A falsy ``val`` must leave ``params`` out of the serialized task entirely.
        """
        dag = DAG(dag_id='simple_dag')
        BaseOperator(task_id='simple_task', dag=dag, params=val, start_date=datetime(2019, 8, 1))

        serialized_dag = SerializedDAG.to_dict(dag)
        task_blob = serialized_dag["dag"]["tasks"][0]
        assert ("params" in task_blob) == bool(val)

        restored_task = SerializedDAG.from_dict(serialized_dag).task_dict["simple_task"]
        assert expected_val == restored_task.params
示例#8
0
    def test_templated_fields_exist_in_serialized_dag(self, templated_field, expected_field):
        """
        Test that a templated field of an operator exists in the Serialized DAG.

        Since we don't want to inflate arbitrary python objects (it poses a RCE/security risk etc.)
        we want to check that non-"basic" objects are turned into strings after deserializing.
        The templated field exercised here is ``BashOperator.bash_command``.
        """

        dag = DAG("test_serialized_template_fields", start_date=datetime(2019, 8, 1))
        with dag:
            BashOperator(task_id="test", bash_command=templated_field)

        serialized_dag = SerializedDAG.to_dict(dag)
        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
        deserialized_test_task = deserialized_dag.task_dict["test"]
        assert expected_field == deserialized_test_task.bash_command
示例#9
0
    def test_task_group_serialization(self):
        """
        Round-trip a DAG containing nested TaskGroups and compare every leaf task.

        The top-level group's children must match, and each leaf task in the
        round-tripped group tree must re-serialize identically to the original
        task after its own serialize/deserialize cycle.
        """
        from airflow.operators.dummy import DummyOperator
        from airflow.utils.task_group import TaskGroup

        start = datetime(2020, 1, 1)
        with DAG("test_task_group_serialization", start_date=start) as dag:
            first = DummyOperator(task_id="task1")
            with TaskGroup("group234") as outer_group:
                DummyOperator(task_id="task2")

                with TaskGroup("group34") as inner_group:
                    DummyOperator(task_id="task3")
                    DummyOperator(task_id="task4")

            last = DummyOperator(task_id="task5")
            first >> outer_group
            inner_group >> last

        SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))
        from_json = SerializedDAG.from_json(SerializedDAG.to_json(dag))
        self.validate_deserialized_dag(from_json, dag)

        round_tripped = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))

        assert round_tripped.task_group.children
        assert round_tripped.task_group.children.keys() == dag.task_group.children.keys()

        def assert_leaves_match(node):
            try:
                members = node.children.values()
            except AttributeError:
                # No ``children``: treat the node as a task and compare it with the
                # original task after that task's own serialization round trip.
                reference = SerializedBaseOperator.serialize_operator(
                    SerializedBaseOperator.deserialize_operator(
                        SerializedBaseOperator.serialize_operator(dag.get_task(node.task_id))
                    )
                )
                assert node
                assert SerializedBaseOperator.serialize_operator(node) == reference
            else:
                for member in members:
                    assert_leaves_match(member)

        assert_leaves_match(round_tripped.task_group)
    def test_dag_params_roundtrip(self, val, expected_val):
        """
        Check that DAG-level ``params`` round-trip and are visible on both the DAG and its task.

        A falsy ``val`` must leave ``params`` out of the serialized DAG entirely.
        """
        dag = DAG(dag_id='simple_dag', params=val)
        BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

        serialized_dag = SerializedDAG.to_dict(dag)
        check_presence = self.assertIn if val else self.assertNotIn
        check_presence("params", serialized_dag["dag"])

        restored_dag = SerializedDAG.from_dict(serialized_dag)
        restored_task = restored_dag.task_dict["simple_task"]
        for actual_params in (restored_dag.params, restored_task.params):
            self.assertEqual(expected_val, actual_params)
示例#11
0
    def test_console_extra_link_serialized_field(self):
        """Round-trip an MLEngine training operator and check its AI Platform console link.

        The link must resolve to the job's console URL for the execution date
        whose XCom holds ``gcp_metadata``, and to ``''`` for a date without it.
        """
        training_args = self.TRAINING_DEFAULT_ARGS
        with self.dag:
            training_op = MLEngineStartTrainingJobOperator(**training_args)
        serialized_dag = SerializedDAG.to_dict(self.dag)
        simple_task = SerializedDAG.from_dict(serialized_dag).task_dict[training_args['task_id']]

        # Serialized side: exactly one console link, with no extra attributes.
        self.assertEqual(
            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
            [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}],
        )

        # Deserialized side: the link is rebuilt as an AIPlatformConsoleLink.
        self.assertIsInstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)

        job_id, project_id = training_args['job_id'], training_args['project_id']
        task_instance = TaskInstance(task=training_op, execution_date=DEFAULT_DATE)
        task_instance.xcom_push(
            key='gcp_metadata',
            value={"job_id": job_id, "project_id": project_id},
        )

        cases = (
            (DEFAULT_DATE,
             f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"),
            (datetime.datetime(2019, 1, 1), ''),
        )
        for execution_date, expected_url in cases:
            self.assertEqual(
                expected_url,
                simple_task.get_extra_links(execution_date, AIPlatformConsoleLink.name),
            )
示例#12
0
    def test_bigquery_operator_extra_serialized_field_when_single_query(self):
        """Round-trip a single-statement BigQuery operator and check its console link.

        The link must resolve to the console URL for the date with a pushed
        ``job_id`` and to ``''`` for a date without one.
        """
        query = 'SELECT * FROM test_table'
        with self.dag:
            BigQueryExecuteQueryOperator(task_id=TASK_ID, sql=query)

        serialized_dag = SerializedDAG.to_dict(self.dag)
        serialized_task = serialized_dag["dag"]["tasks"][0]
        self.assertIn("sql", serialized_task)

        simple_task = SerializedDAG.from_dict(serialized_dag).task_dict[TASK_ID]
        self.assertEqual(simple_task.sql, query)

        # Serialized side: a single, non-indexed console link.
        self.assertEqual(
            serialized_task["_operator_extra_links"],
            [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}],
        )

        # Deserialized side: rebuilt as a BigQueryConsoleLink.
        self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)

        task_instance = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
        task_instance.xcom_push('job_id', 12345)

        link_name = BigQueryConsoleLink.name
        # Positive case: DEFAULT_DATE has a job id in XCom.
        self.assertEqual(simple_task.get_extra_links(DEFAULT_DATE, link_name),
                         'https://console.cloud.google.com/bigquery?j=12345')
        # Negative case: nothing was pushed for this date.
        self.assertEqual(simple_task.get_extra_links(datetime(2017, 1, 2), link_name), '')
示例#13
0
    def test_dag_on_success_callback_roundtrip(self, passed_success_callback, expected_value):
        """
        Check that ``has_on_success_callback`` round-trips through DAG serialization.

        The flag is expected in the serialized ``dag`` blob only when
        ``expected_value`` is truthy, and the deserialized DAG's
        ``has_on_success_callback`` must be exactly ``expected_value``.
        """
        dag = DAG(dag_id='test_dag_on_success_callback_roundtrip', **passed_success_callback)
        BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))

        serialized_dag = SerializedDAG.to_dict(dag)
        flag_stored = "has_on_success_callback" in serialized_dag["dag"]
        assert flag_stored == bool(expected_value)

        restored_dag = SerializedDAG.from_dict(serialized_dag)
        assert restored_dag.has_on_success_callback is expected_value
示例#14
0
    def test_edge_info_serialization(self):
        """
        Check that edge labels (``edge_info``) survive DAG serialization.
        """
        from airflow.operators.dummy import DummyOperator
        from airflow.utils.edgemodifier import Label

        with DAG("test_edge_info_serialization", start_date=datetime(2020, 1, 1)) as dag:
            upstream = DummyOperator(task_id="task1")
            downstream = DummyOperator(task_id="task2")
            upstream >> Label("test label") >> downstream  # pylint: disable=W0106

        SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))
        from_json = SerializedDAG.from_json(SerializedDAG.to_json(dag))
        self.validate_deserialized_dag(from_json, dag)

        round_tripped = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
        assert round_tripped.edge_info == dag.edge_info
示例#15
0
    def __init__(self, dag: DAG):
        """Build the stored record for ``dag``.

        The DAG is serialized once; its sorted-key JSON bytes are hashed into
        ``dag_hash`` and, when ``COMPRESS_SERIALIZED_DAGS`` is set, stored
        zlib-compressed in ``_data_compressed`` instead of the plain dict in
        ``_data``.
        """
        self.dag_id = dag.dag_id
        self.fileloc = dag.fileloc
        self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
        self.last_updated = timezone.utcnow()

        payload = SerializedDAG.to_dict(dag)
        # sort_keys makes the byte encoding (and therefore the hash) independent
        # of dict insertion order.
        encoded = json.dumps(payload, sort_keys=True).encode("utf-8")
        self.dag_hash = hashlib.md5(encoded).hexdigest()

        # Exactly one of _data / _data_compressed is populated.
        compress = COMPRESS_SERIALIZED_DAGS
        self._data = None if compress else payload
        self._data_compressed = zlib.compress(encoded) if compress else None

        # Keep the decoded dict around so reading the data does not require
        # decompressing and re-parsing when compression is enabled.
        self.__data_cache = payload
示例#16
0
    def test_dagrun_update_state_with_handle_callback_failure(self):
        """
        A failed DagRun updated with ``execute_callbacks=False`` returns a failure callback request.

        One task succeeds and its downstream task fails, so the run ends FAILED;
        the returned callback must be a ``DagCallbackRequest`` flagged as a
        failure callback with message ``"task_failure"``.
        """
        def on_failure_callable(context):
            self.assertEqual(
                context['dag_run'].dag_id,
                'test_dagrun_update_state_with_handle_callback_failure')

        dag = DAG(
            dag_id='test_dagrun_update_state_with_handle_callback_failure',
            start_date=datetime.datetime(2017, 1, 1),
            on_failure_callback=on_failure_callable,
        )
        dag_task1 = DummyOperator(task_id='test_state_succeeded1', dag=dag)
        dag_task2 = DummyOperator(task_id='test_state_failed2', dag=dag)
        dag_task1.set_downstream(dag_task2)

        initial_task_states = {
            'test_state_succeeded1': State.SUCCESS,
            'test_state_failed2': State.FAILED,
        }

        # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

        dag_run = self.create_dag_run(dag=dag,
                                      state=State.RUNNING,
                                      task_states=initial_task_states)

        _, callback = dag_run.update_state(execute_callbacks=False)
        self.assertEqual(State.FAILED, dag_run.state)
        # With execute_callbacks=False, update_state() is expected to hand the
        # failure callback back as a DagCallbackRequest instead of running it here.

        assert callback == DagCallbackRequest(
            full_filepath=dag_run.dag.fileloc,
            dag_id="test_dagrun_update_state_with_handle_callback_failure",
            execution_date=dag_run.execution_date,
            is_failure_callback=True,
            msg="task_failure",
        )
示例#17
0
 def test_deserialization_with_dag_context(self):
     """Serializing a DAG while its ``with DAG(...)`` context is still active must succeed."""
     start = datetime(2019, 8, 1, tzinfo=timezone.utc)
     with DAG(dag_id='simple_dag', start_date=start) as dag:
         BaseOperator(task_id='simple_task')
         # to_dict() runs inside the active DAG context; it should not raise
         # RuntimeError: dictionary changed size during iteration
         SerializedDAG.to_dict(dag)
示例#18
0
 def __init__(self, dag):
     """Capture the identity, file location, serialized form and update time of ``dag``."""
     self.dag_id = dag.dag_id
     source_path = dag.full_filepath
     self.fileloc = source_path
     self.fileloc_hash = DagCode.dag_fileloc_hash(source_path)
     self.data = SerializedDAG.to_dict(dag)
     self.last_updated = timezone.utcnow()