示例#1
0
 def dag(self):
     """Return the DAG rebuilt from the ``data`` column.

     ``data`` is handed to ``SerializedDAG.from_dict`` when it is a ``dict``
     and to ``SerializedDAG.from_json`` in every other case.
     """
     if not isinstance(self.data, dict):
         return SerializedDAG.from_json(self.data)  # noqa
     return SerializedDAG.from_dict(self.data)
示例#2
0
    def test_deserialization_across_process(self):
        """DAGs received from a child process deserialize here and match local DAGs."""

        # The example DAGs get parsed twice (by the child process and again
        # below to build the comparison set), so only "airflow/example_dags"
        # is loaded instead of every DAG.
        queue = multiprocessing.Queue()
        proc = multiprocessing.Process(
            target=serialize_subprocess,
            args=(queue, "airflow/example_dags"),
        )
        proc.daemon = True
        proc.start()

        # Read from the queue until the ``None`` end marker shows up; every
        # other item is treated as a JSON payload for one DAG.
        stringified_dags = {}
        payload = queue.get()
        while payload is not None:
            restored = SerializedDAG.from_json(payload)
            assert isinstance(restored, DAG)
            stringified_dags[restored.dag_id] = restored
            payload = queue.get()

        dags = collect_dags("airflow/example_dags")
        received_ids = set(stringified_dags.keys())
        expected_ids = set(dags.keys())
        assert received_ids == expected_ids

        # Compare each deserialized DAG with its locally collected counterpart.
        for dag_id, restored in stringified_dags.items():
            self.validate_deserialized_dag(restored, dags[dag_id])
示例#3
0
 def dag(self):
     """Build and return the DAG stored in the ``data`` column.

     A ``dict`` value goes through ``SerializedDAG.from_dict``; anything
     else goes through ``SerializedDAG.from_json``.
     """
     if isinstance(self.data, dict):
         return SerializedDAG.from_dict(self.data)
     # noinspection PyTypeChecker
     return SerializedDAG.from_json(self.data)
示例#4
0
    def dag(self):
        """Return the DAG rebuilt from the ``data`` column.

        Before deserializing, this instance's ``load_op_links`` value is
        assigned to ``SerializedDAG._load_operator_extra_links`` (a class-level
        attribute, so the assignment is visible outside this instance).
        """
        SerializedDAG._load_operator_extra_links = self.load_op_links

        # ``data`` may be a dict or not; each form has its own entry point.
        if not isinstance(self.data, dict):
            return SerializedDAG.from_json(self.data)
        return SerializedDAG.from_dict(self.data)
示例#5
0
    def dag(self):
        """Deserialize the ``data`` column into a DAG and return it.

        ``SerializedDAG._load_operator_extra_links`` is set from this
        instance's ``load_op_links`` first; then ``data`` is deserialized with
        ``from_dict`` when it is a ``dict`` and with ``from_json`` otherwise.
        """
        SerializedDAG._load_operator_extra_links = self.load_op_links  # pylint: disable=protected-access

        deserialize = (
            SerializedDAG.from_dict
            if isinstance(self.data, dict)
            else SerializedDAG.from_json
        )
        return deserialize(self.data)  # noqa
    def test_roundtrip_provider_example_dags(self):
        """Round-trip the DAGs collected from the provider example paths through JSON."""
        provider_patterns = [
            "airflow/providers/*/example_dags",
            "airflow/providers/*/*/example_dags",
        ]
        dags = collect_dags(provider_patterns)

        # Serialize each DAG to JSON, read it back, and validate the result
        # against the DAG it came from.
        for original in dags.values():
            as_json = SerializedDAG.to_json(original)
            restored = SerializedDAG.from_json(as_json)
            self.validate_deserialized_dag(restored, original)
    def test_task_group_serialization(self):
        """
        Check that a DAG containing nested TaskGroups survives
        serialization/deserialization.
        """
        from airflow.operators.dummy_operator import DummyOperator
        from airflow.utils.task_group import TaskGroup

        start = datetime(2020, 1, 1)
        with DAG("test_task_group_serialization", start_date=start) as dag:
            first_task = DummyOperator(task_id="task1")
            with TaskGroup("group234") as outer_group:
                _ = DummyOperator(task_id="task2")

                with TaskGroup("group34") as inner_group:
                    _ = DummyOperator(task_id="task3")
                    _ = DummyOperator(task_id="task4")

            last_task = DummyOperator(task_id="task5")
            first_task >> outer_group
            inner_group >> last_task

        # The dict form must satisfy the serialization schema.
        SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))

        # JSON round trip.
        as_json = SerializedDAG.to_json(dag)
        self.validate_deserialized_dag(SerializedDAG.from_json(as_json), dag)

        # Round trip through serialize_dag / deserialize_dag.
        restored = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))

        assert restored.task_group.children
        assert restored.task_group.children.keys() == dag.task_group.children.keys()

        def verify_node(node):
            # A node exposing ``children`` is walked recursively; a node
            # without that attribute is handled as a task.
            try:
                members = node.children.values()
            except AttributeError:
                # Round-trip the matching task from the original DAG and
                # compare its serialized form with this node's.
                source_task = dag.get_task(node.task_id)
                expected_serialized = SerializedBaseOperator.serialize_operator(source_task)
                expected_deserialized = SerializedBaseOperator.deserialize_operator(expected_serialized)
                expected_dict = SerializedBaseOperator.serialize_operator(expected_deserialized)
                assert node
                assert SerializedBaseOperator.serialize_operator(node) == expected_dict
            else:
                for member in members:
                    verify_node(member)

        verify_node(restored.task_group)
    def test_deserialization(self):
        """A serialized DAG can be deserialized in another process.

        A child process running ``serialize_subprocess`` feeds a queue; every
        item read before the ``None`` end marker is deserialized with
        ``SerializedDAG.from_json`` and compared with the DAGs returned by
        ``collect_dags()``. Two specific example DAGs are then inspected
        task by task.
        """
        queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=serialize_subprocess,
                                       args=(queue, ))
        proc.daemon = True
        proc.start()

        # Drain the queue until the ``None`` end marker arrives.
        stringified_dags = {}
        while True:
            v = queue.get()
            if v is None:
                break
            dag = SerializedDAG.from_json(v)
            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(dag, DAG)
            stringified_dags[dag.dag_id] = dag

        dags = collect_dags()
        # assertEqual on sets shows which dag_ids differ when the check fails.
        self.assertEqual(set(stringified_dags.keys()), set(dags.keys()))

        # Verify deserialized DAGs.
        for dag_id in stringified_dags:
            self.validate_deserialized_dag(stringified_dags[dag_id],
                                           dags[dag_id])

        example_skip_dag = stringified_dags['example_skip_dag']
        skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
        self.validate_deserialized_task(skip_operator_1_task,
                                        'DummySkipOperator', '#e8b7e4', '#000')

        # Verify that the DAG object has 'full_filepath' attribute
        # and is equal to fileloc
        self.assertTrue(hasattr(example_skip_dag, 'full_filepath'))
        self.assertEqual(example_skip_dag.full_filepath,
                         example_skip_dag.fileloc)

        example_subdag_operator = stringified_dags['example_subdag_operator']
        section_1_task = example_subdag_operator.task_dict['section-1']
        self.validate_deserialized_task(section_1_task,
                                        SubDagOperator.__name__,
                                        SubDagOperator.ui_color,
                                        SubDagOperator.ui_fgcolor)
示例#9
0
    def test_edge_info_serialization(self):
        """
        Check that ``edge_info`` is preserved by serialization/deserialization.
        """
        from airflow.operators.dummy import DummyOperator
        from airflow.utils.edgemodifier import Label

        with DAG("test_edge_info_serialization", start_date=datetime(2020, 1, 1)) as dag:
            upstream = DummyOperator(task_id="task1")
            downstream = DummyOperator(task_id="task2")
            upstream >> Label("test label") >> downstream  # pylint: disable=W0106

        # The dict form must satisfy the serialization schema.
        SerializedDAG.validate_schema(SerializedDAG.to_dict(dag))

        # JSON round trip.
        as_json = SerializedDAG.to_json(dag)
        self.validate_deserialized_dag(SerializedDAG.from_json(as_json), dag)

        # Round trip through serialize_dag / deserialize_dag and compare the
        # edge metadata with the original DAG's.
        restored = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))

        assert restored.edge_info == dag.edge_info