def test_roundtrip_provider_example_dags(self):
        """Round-trip every collected provider example DAG through JSON and validate it."""
        provider_patterns = [
            "airflow/providers/*/example_dags",
            "airflow/providers/*/*/example_dags",
        ]
        collected = collect_dags(provider_patterns)

        # Serialize each DAG to JSON, read it back, then compare with the source DAG.
        for original in collected.values():
            as_json = SerializedDAG.to_json(original)
            restored = SerializedDAG.from_json(as_json)
            self.validate_deserialized_dag(restored, original)
    def test_task_group_serialization(self):
        """
        Check that nested TaskGroups survive a serialize/deserialize round trip.

        Builds a DAG with one group nested inside another, validates the
        serialized dict against the schema, then walks the restored group tree
        and compares each leaf with a round-tripped copy of the original task.
        """
        from airflow.operators.dummy_operator import DummyOperator
        from airflow.utils.task_group import TaskGroup

        start = datetime(2020, 1, 1)
        with DAG("test_task_group_serialization", start_date=start) as dag:
            first_task = DummyOperator(task_id="task1")
            with TaskGroup("group234") as outer_group:
                DummyOperator(task_id="task2")

                with TaskGroup("group34") as inner_group:
                    DummyOperator(task_id="task3")
                    DummyOperator(task_id="task4")

            last_task = DummyOperator(task_id="task5")
            first_task >> outer_group
            inner_group >> last_task

        # Schema validation of the dict form, then a JSON round trip.
        SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))
        from_json = SerializedDAG.from_json(SerializedDAG.to_json(dag))
        self.validate_deserialized_dag(from_json, dag)

        restored_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))

        # The restored root group must be non-empty and expose the same child keys.
        restored_children = restored_dag.task_group.children
        assert restored_children
        assert restored_children.keys() == dag.task_group.children.keys()

        def check_task_group(node):
            """Recursively verify every leaf reachable from ``node``."""
            try:
                children = node.children.values()
            except AttributeError:
                # No usable ``children`` mapping: handle the node as a task.
                # Round-trip the original task and compare serialized forms.
                original_task = dag.get_task(node.task_id)
                first_pass = SerializedBaseOperator.serialize_operator(original_task)
                rebuilt = SerializedBaseOperator.deserialize_operator(first_pass)
                expected = SerializedBaseOperator.serialize_operator(rebuilt)
                assert node
                actual = SerializedBaseOperator.serialize_operator(node)
                assert actual == expected
            else:
                for child in children:
                    check_task_group(child)

        check_task_group(restored_dag.task_group)
# Example #3
# 0
    def test_edge_info_serialization(self):
        """
        Check that edge_info is preserved by a serialize/deserialize round trip.
        """
        from airflow.operators.dummy import DummyOperator
        from airflow.utils.edgemodifier import Label

        start = datetime(2020, 1, 1)
        with DAG("test_edge_info_serialization", start_date=start) as dag:
            upstream = DummyOperator(task_id="task1")
            downstream = DummyOperator(task_id="task2")
            upstream >> Label("test label") >> downstream  # pylint: disable=W0106

        # Schema validation of the dict form, then a JSON round trip.
        SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))
        from_json = SerializedDAG.from_json(SerializedDAG.to_json(dag))
        self.validate_deserialized_dag(from_json, dag)

        encoded = SerializedDAG.serialize_dag(dag)
        restored_dag = SerializedDAG.deserialize_dag(encoded)

        assert restored_dag.edge_info == dag.edge_info
# Example #4
# 0
def serialize_subprocess(queue, dag_folder):
    """Serialize each DAG collected from ``dag_folder`` to JSON and put it on ``queue``.

    After the last DAG, a final ``None`` is put on the queue.
    """
    for collected_dag in collect_dags(dag_folder).values():
        queue.put(SerializedDAG.to_json(collected_dag))
    queue.put(None)