Example #1
0
    def dag(self):
        """Return the DAG rebuilt from the ``data`` column.

        ``data`` is passed to ``SerializedDAG.from_dict`` when it is a
        ``dict``; any other value is passed to ``SerializedDAG.from_json``.
        """
        from airflow.serialization import SerializedDAG  # noqa # pylint: disable=redefined-outer-name

        payload = self.data
        if isinstance(payload, dict):
            return SerializedDAG.from_dict(payload)

        # Not a dict, so hand it to the JSON decoder.
        # noinspection PyTypeChecker
        return SerializedDAG.from_json(payload)
Example #2
0
    def test_serialization(self):
        """Every collected DAG serializes to a schema-valid dict.

        The serialized form of ``simple_dag`` is additionally compared with
        the stored ground truth.
        """
        serialized_dags = {}
        for source_dag in collect_dags().values():
            encoded = SerializedDAG.to_dict(source_dag)
            SerializedDAG.validate_schema(encoded)
            serialized_dags[source_dag.dag_id] = encoded

        # Compare against the JSON ground truth.
        self.validate_serialized_dag(
            serialized_dags['simple_dag'],
            serialized_simple_dag_ground_truth,
        )
Example #3
0
    def __init__(self, dag: 'DAG'):
        """Populate this row from ``dag``.

        Stores the DAG id, its file path and the hash of that path, the
        dict produced by ``SerializedDAG.to_dict``, and the current UTC time.

        :param dag: the DAG to serialize and store
        """
        from airflow.serialization import SerializedDAG  # noqa # pylint: disable=redefined-outer-name

        location = dag.full_filepath

        self.dag_id = dag.dag_id
        self.fileloc = location
        self.fileloc_hash = self.dag_fileloc_hash(location)
        self.data = SerializedDAG.to_dict(dag)
        self.last_updated = timezone.utcnow()
Example #4
0
    def test_deserialization_schedule_interval(self,
                                               serialized_schedule_interval,
                                               expected):
        """A serialized ``schedule_interval`` deserializes to ``expected``.

        Builds a minimal serialized DAG around the given interval, checks it
        against the schema, and compares the deserialized DAG's
        ``schedule_interval`` with ``expected``.
        """
        dag_payload = {
            "default_args": {"__type": "dict", "__var": {}},
            "params": {},
            "_dag_id": "simple_dag",
            "fileloc": __file__,
            "tasks": [],
            "timezone": "UTC",
            "schedule_interval": serialized_schedule_interval,
        }
        envelope = {"__version": 1, "dag": dag_payload}

        SerializedDAG.validate_schema(envelope)

        restored = SerializedDAG.from_dict(envelope)
        self.assertEqual(restored.schedule_interval, expected)
Example #5
0
    def test_deserialization(self):
        """A serialized DAG can be deserialized in another process.

        A daemon subprocess running ``serialize_subprocess`` puts one JSON
        payload per DAG on a queue, followed by ``None``. Each payload is
        deserialized here and the results are checked against the DAGs
        collected in this process.
        """
        queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=serialize_subprocess,
                                       args=(queue, ))
        proc.daemon = True
        proc.start()

        stringified_dags = {}
        while True:
            payload = queue.get()
            # None is the end-of-stream marker put by the subprocess.
            if payload is None:
                break
            dag = SerializedDAG.from_json(payload)
            # assertIsInstance reports the actual type on failure.
            self.assertIsInstance(dag, DAG)
            stringified_dags[dag.dag_id] = dag

        dags = collect_dags()
        # assertEqual on sets reports which ids are missing or unexpected.
        self.assertEqual(set(stringified_dags.keys()), set(dags.keys()))

        # Verify deserialized DAGs.
        example_skip_dag = stringified_dags['example_skip_dag']
        skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
        self.validate_deserialized_task(skip_operator_1_task,
                                        'DummySkipOperator', '#e8b7e4', '#000')

        # Verify that the DAG object has 'full_filepath' attribute
        # and is equal to fileloc
        self.assertTrue(hasattr(example_skip_dag, 'full_filepath'))
        self.assertEqual(example_skip_dag.full_filepath,
                         example_skip_dag.fileloc)

        example_subdag_operator = stringified_dags['example_subdag_operator']
        section_1_task = example_subdag_operator.task_dict['section-1']
        self.validate_deserialized_task(section_1_task,
                                        SubDagOperator.__name__,
                                        SubDagOperator.ui_color,
                                        SubDagOperator.ui_fgcolor)

        simple_dag = stringified_dags['simple_dag']
        custom_task = simple_dag.task_dict['custom_task']
        self.validate_operator_extra_links(custom_task)
    def deserialize_operator(cls, encoded_op: dict) -> BaseOperator:
        """Rebuild an operator from its encoded dict form.

        A ``SerializedBaseOperator`` is created from ``encoded_op['task_id']``,
        plugin-provided extra links whose operator class name and module match
        ``_task_type`` / ``_task_module`` are attached, and every entry of
        ``encoded_op`` is then set as an attribute (decoded first where the
        field requires it). Serialized fields that are absent from
        ``encoded_op`` and are not constructor parameters are set to ``None``.

        :param encoded_op: the encoded operator; must contain ``task_id``
        :return: the populated operator
        """
        from airflow.serialization import SerializedDAG
        from airflow.plugins_manager import operator_extra_links

        op = SerializedBaseOperator(task_id=encoded_op['task_id'])

        # Extra Operator Links, keyed by name so a later link with the same
        # name replaces an earlier one.
        plugin_links = {}
        for link in operator_extra_links:
            for operator_class in link.operators:
                if operator_class.__name__ != encoded_op["_task_type"]:
                    continue
                if operator_class.__module__ == encoded_op["_task_module"]:
                    plugin_links[link.name] = link

        op.operator_extra_links = list(plugin_links.values())

        for field, value in encoded_op.items():
            if field == "_downstream_task_ids":
                value = set(value)
            elif field == "subdag":
                value = SerializedDAG.deserialize_dag(value)
            elif field in ("retry_delay", "execution_timeout"):
                value = cls._deserialize_timedelta(value)
            elif field.endswith("_date"):
                value = cls._deserialize_datetime(value)
            elif (field in cls._decorated_fields
                  or field not in op._serialized_fields):  # pylint: disable=protected-access
                value = cls._deserialize(value)
            # Anything else is stored exactly as encoded.

            setattr(op, field, value)

        # pylint: disable=protected-access
        absent_fields = (op._serialized_fields
                         - encoded_op.keys()
                         - cls._CONSTRUCTOR_PARAMS.keys())
        for field in absent_fields:
            setattr(op, field, None)

        return op
Example #7
0
    def test_roundtrip_relativedelta(self, val, expected):
        """``val`` serializes to ``expected`` and deserializes back to ``val``."""
        encoded = SerializedDAG._serialize(val)
        self.assertDictEqual(encoded, expected)

        decoded = SerializedDAG._deserialize(encoded)
        self.assertEqual(val, decoded)
Example #8
0
def serialize_subprocess(queue):
    """Put the JSON form of every collected DAG on ``queue``.

    A final ``None`` is put after the last DAG to mark the end of the stream.

    :param queue: object with a ``put`` method that receives the payloads
    """
    for collected_dag in collect_dags().values():
        payload = SerializedDAG.to_json(collected_dag)
        queue.put(payload)

    # End-of-stream marker.
    queue.put(None)