Пример #1
0
class SerializedDAG(DAG, BaseSerialization):
    """
    A JSON serializable representation of DAG.

    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
    that are not serializable, such as functions and customer defined classes, are cast to
    strings.

    Compared with SimpleDAG: SerializedDAG contains all information for webserver.
    Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are
    not pickle-able. SerializedDAG works for all DAGs.
    """

    # Fields handed to ``serialize_to_json`` as "decorated"; when reading a blob back,
    # keys in this set are passed through the generic ``_deserialize``.
    _decorated_fields = {
        'schedule_interval', 'default_args', '_access_control'
    }

    @staticmethod
    def __get_constructor_defaults():  # pylint: disable=no-method-argument
        """Collect the ``DAG`` constructor parameters that declare a default.

        :return: dict keyed by attribute name (parameters listed in ``param_to_attr``
            are keyed by the attribute name given there, all others by their own
            name) whose values are the parameter objects yielded by ``signature``.
        """
        param_to_attr = {
            'concurrency': '_concurrency',
            'description': '_description',
            'default_view': '_default_view',
            'access_control': '_access_control',
        }
        return {
            param_to_attr.get(k, k): v
            for k, v in signature(DAG).parameters.items()
            if v.default is not v.empty
        }

    # Computed once while the class body executes; the helper is then removed from
    # the class namespace so it never becomes an attribute of the class.
    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
    del __get_constructor_defaults

    _json_schema = load_dag_schema()

    @classmethod
    def serialize_dag(cls, dag: DAG) -> dict:
        """Serializes a DAG into a JSON object.

        :param dag: the DAG to serialize.
        :return: the dict produced by ``serialize_to_json`` with an additional
            ``"tasks"`` entry holding one serialized item per task in ``dag.task_dict``.
        """
        serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

        serialized_dag["tasks"] = [
            cls._serialize(task) for task in dag.task_dict.values()
        ]
        return serialized_dag

    @classmethod
    def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
        """Deserializes a DAG from a JSON object.

        :param encoded_dag: the ``"dag"`` part of a serialized blob; must contain
            ``'_dag_id'`` (a ``KeyError`` is raised otherwise).
        :return: a ``SerializedDAG`` whose attributes are populated from the blob and
            whose tasks are attached to it.
        """
        dag = SerializedDAG(dag_id=encoded_dag['_dag_id'])

        # Copy every entry of the blob onto the DAG, decoding the values that need it.
        for k, v in encoded_dag.items():
            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "tasks":
                # The task list is stored on the DAG as ``task_dict``, keyed by task_id.
                v = {
                    task["task_id"]:
                    SerializedBaseOperator.deserialize_operator(task)
                    for task in v
                }
                k = "task_dict"
            elif k == "timezone":
                v = cls._deserialize_timezone(v)
            elif k == "dagrun_timeout":
                v = cls._deserialize_timedelta(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k in cls._decorated_fields:
                v = cls._deserialize(v)
            # else use v as it is

            setattr(dag, k, v)

        # Serialized fields that are missing from the blob and are not constructor
        # parameters with a default are explicitly reset to None.
        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys(
        ) - cls._CONSTRUCTOR_PARAMS.keys()
        for k in keys_to_set_none:
            setattr(dag, k, None)

        setattr(dag, 'full_filepath', dag.fileloc)
        for task in dag.task_dict.values():
            task.dag = dag
            serializable_task: BaseOperator = task

            # Tasks without their own dates take the DAG's dates.
            for date_attr in ["start_date", "end_date"]:
                if getattr(serializable_task, date_attr) is None:
                    setattr(serializable_task, date_attr,
                            getattr(dag, date_attr))

            if serializable_task.subdag is not None:
                setattr(serializable_task.subdag, 'parent_dag', dag)
                serializable_task.subdag.is_subdag = True

            for task_id in serializable_task.downstream_task_ids:
                # Bypass set_upstream etc here - it does more than we want
                # Register *this* task as upstream of each of its downstream tasks
                # (adding ``task_id`` here would make the downstream task its own
                # upstream and lose the real edge).
                # noinspection PyProtectedMember
                dag.task_dict[task_id]._upstream_task_ids.add(  # pylint: disable=protected-access
                    serializable_task.task_id)

        return dag

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var.

        :param var: the DAG to serialize (it is passed to ``serialize_dag``).
        :return: dict with ``"__version"`` set to ``SERIALIZER_VERSION`` and ``"dag"``
            set to the serialized DAG.
        """
        json_dict = {
            "__version": cls.SERIALIZER_VERSION,
            "dag": cls.serialize_dag(var)
        }

        # Validate Serialized DAG with Json Schema. Raises Error if it mismatches
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG':
        """Deserializes a python dict in to the DAG and operators it contains.

        :param serialized_obj: dict with ``'__version'`` and ``'dag'`` entries.
        :raises ValueError: if ``'__version'`` is absent or differs from
            ``SERIALIZER_VERSION``.
        """
        ver = serialized_obj.get('__version', '<not present>')
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError(
                "Unsure how to deserialize version {!r}".format(ver))
        return cls.deserialize_dag(serialized_obj['dag'])
Пример #2
0
class SerializedDAG(DAG, BaseSerialization):
    """
    A JSON serializable representation of DAG.

    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
    that are not serializable, such as functions and customer defined classes, are cast to
    strings.

    Compared with SimpleDAG: SerializedDAG contains all information for webserver.
    Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are
    not pickle-able. SerializedDAG works for all DAGs.
    """

    # Fields handed to ``serialize_to_json`` as "decorated"; when reading a blob back,
    # keys in this set are passed through the generic ``_deserialize``.
    _decorated_fields = {'schedule_interval', 'default_args', '_access_control'}

    @staticmethod
    def __get_constructor_defaults():  # pylint: disable=no-method-argument
        """Collect the default values of the ``DAG.__init__`` parameters.

        :return: dict keyed by attribute name (parameters listed in ``param_to_attr``
            are keyed by the attribute name given there, all others by their own
            name) whose values are the declared defaults. Parameters without a
            default are omitted.
        """
        param_to_attr = {
            'concurrency': '_concurrency',
            'description': '_description',
            'default_view': '_default_view',
            'access_control': '_access_control',
        }
        return {
            param_to_attr.get(k, k): v.default
            for k, v in signature(DAG.__init__).parameters.items()
            if v.default is not v.empty
        }

    # Computed once while the class body executes; the helper is then removed from
    # the class namespace so it never becomes an attribute of the class.
    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
    del __get_constructor_defaults

    _json_schema = load_dag_schema()

    @classmethod
    def serialize_dag(cls, dag: DAG) -> dict:
        """Serializes a DAG into a JSON object.

        :param dag: the DAG to serialize.
        :return: the dict produced by ``serialize_to_json`` extended with ``"tasks"``,
            ``"_task_group"``, ``"edge_info"`` and, only when true, the
            ``has_on_*_callback`` flags.
        :raises SerializationError: re-raised unchanged if raised while serializing;
            any other exception is wrapped in a ``SerializationError`` that carries
            the original exception as its cause.
        """
        try:
            serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

            serialized_dag["tasks"] = [cls._serialize(task) for task in dag.task_dict.values()]
            serialized_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)

            # Edge info in the JSON exactly matches our internal structure
            serialized_dag["edge_info"] = dag.edge_info

            # has_on_*_callback are only stored if the value is True, as the default is False
            if dag.has_on_success_callback:
                serialized_dag['has_on_success_callback'] = True
            if dag.has_on_failure_callback:
                serialized_dag['has_on_failure_callback'] = True
            return serialized_dag
        except SerializationError:
            raise
        except Exception as err:
            # Chain explicitly so the root cause stays attached to the wrapper.
            raise SerializationError(f'Failed to serialize dag {dag.dag_id!r}') from err

    @classmethod
    def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
        """Deserializes a DAG from a JSON object.

        :param encoded_dag: the ``"dag"`` part of a serialized blob; must contain
            ``'_dag_id'`` (a ``KeyError`` is raised otherwise).
        :return: a ``SerializedDAG`` whose attributes are populated from the blob and
            whose tasks and task group are attached to it.
        """
        dag = SerializedDAG(dag_id=encoded_dag['_dag_id'])

        # Copy every entry of the blob onto the DAG, decoding the values that need it.
        for k, v in encoded_dag.items():
            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "tasks":
                # Copy this class's ``_load_operator_extra_links`` setting onto the
                # operator deserializer before any operator is deserialized.
                # pylint: disable=protected-access
                SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
                # pylint: enable=protected-access
                # The task list is stored on the DAG as ``task_dict``, keyed by task_id.
                v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
                k = "task_dict"
            elif k == "timezone":
                v = cls._deserialize_timezone(v)
            elif k == "dagrun_timeout":
                v = cls._deserialize_timedelta(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "edge_info":
                # Value structure matches exactly
                pass
            elif k in cls._decorated_fields:
                v = cls._deserialize(v)
            # else use v as it is

            setattr(dag, k, v)

        # Set _task_group
        # pylint: disable=protected-access
        if "_task_group" in encoded_dag:
            dag._task_group = SerializedTaskGroup.deserialize_task_group(  # type: ignore
                encoded_dag["_task_group"], None, dag.task_dict
            )
        else:
            # This must be old data that had no task_group. Create a root TaskGroup and add
            # all tasks to it.
            dag._task_group = TaskGroup.create_root(dag)
            for task in dag.tasks:
                dag.task_group.add(task)
        # pylint: enable=protected-access

        # Set has_on_*_callbacks to True if they exist in Serialized blob as False is the default
        if "has_on_success_callback" in encoded_dag:
            dag.has_on_success_callback = True
        if "has_on_failure_callback" in encoded_dag:
            dag.has_on_failure_callback = True

        # Serialized fields that are missing from the blob and are not constructor
        # parameters with a default are explicitly reset to None.
        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
        for k in keys_to_set_none:
            setattr(dag, k, None)

        setattr(dag, 'full_filepath', dag.fileloc)
        for task in dag.task_dict.values():
            task.dag = dag
            serializable_task: BaseOperator = task

            # Tasks without their own dates take the DAG's dates.
            for date_attr in ["start_date", "end_date"]:
                if getattr(serializable_task, date_attr) is None:
                    setattr(serializable_task, date_attr, getattr(dag, date_attr))

            if serializable_task.subdag is not None:
                setattr(serializable_task.subdag, 'parent_dag', dag)
                serializable_task.subdag.is_subdag = True

            # Register this task as upstream of each of its downstream tasks.
            for task_id in serializable_task.downstream_task_ids:
                # Bypass set_upstream etc here - it does more than we want
                # noqa: E501 # pylint: disable=protected-access
                dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id)

        return dag

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var.

        :param var: the DAG to serialize (it is passed to ``serialize_dag``).
        :return: dict with ``"__version"`` set to ``SERIALIZER_VERSION`` and ``"dag"``
            set to the serialized DAG.
        """
        json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}

        # Validate Serialized DAG with Json Schema. Raises Error if it mismatches
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG':
        """Deserializes a python dict in to the DAG and operators it contains.

        :param serialized_obj: dict with ``'__version'`` and ``'dag'`` entries.
        :raises ValueError: if ``'__version'`` is absent or differs from
            ``SERIALIZER_VERSION``.
        """
        ver = serialized_obj.get('__version', '<not present>')
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError(f"Unsure how to deserialize version {ver!r}")
        return cls.deserialize_dag(serialized_obj['dag'])
Пример #3
0
class SerializedDAG(DAG, BaseSerialization):
    """
    A JSON serializable representation of DAG.

    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
    that are not serializable, such as functions and customer defined classes, are cast to
    strings.

    Compared with SimpleDAG: SerializedDAG contains all information for webserver.
    Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are
    not pickle-able. SerializedDAG works for all DAGs.
    """

    # Fields handed to ``serialize_to_json`` as "decorated"; when reading a blob back,
    # keys in this set are passed through the generic ``_deserialize``.
    _decorated_fields = {'schedule_interval', 'default_args', '_access_control'}

    @staticmethod
    def __get_constructor_defaults():
        """Collect the default values of the ``DAG.__init__`` parameters.

        :return: dict keyed by attribute name (parameters listed in ``param_to_attr``
            are keyed by the attribute name given there, all others by their own
            name) whose values are the declared defaults. Parameters without a
            default are omitted.
        """
        param_to_attr = {
            'max_active_tasks': '_max_active_tasks',
            'description': '_description',
            'default_view': '_default_view',
            'access_control': '_access_control',
        }
        return {
            param_to_attr.get(k, k): v.default
            for k, v in signature(DAG.__init__).parameters.items()
            if v.default is not v.empty
        }

    # Computed once while the class body executes; the helper is then removed from
    # the class namespace so it never becomes an attribute of the class.
    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
    del __get_constructor_defaults

    _json_schema = load_dag_schema()

    @classmethod
    def serialize_dag(cls, dag: DAG) -> dict:
        """Serializes a DAG into a JSON object.

        :param dag: the DAG to serialize.
        :return: the dict produced by ``serialize_to_json`` with exactly one of
            ``"schedule_interval"`` / ``"timetable"`` removed, extended with
            ``"tasks"``, ``"dag_dependencies"``, ``"_task_group"``, ``"edge_info"``,
            ``"params"`` and, only when true, the ``has_on_*_callback`` flags.
        :raises SerializationError: re-raised unchanged if raised while serializing;
            any other exception is wrapped in a ``SerializationError`` that carries
            the original exception as its cause.
        """
        try:
            serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)

            # If schedule_interval is backed by timetable, serialize only
            # timetable; vice versa for a timetable backed by schedule_interval.
            if dag.timetable.summary == dag.schedule_interval:
                del serialized_dag["schedule_interval"]
            else:
                del serialized_dag["timetable"]

            serialized_dag["tasks"] = [cls._serialize(task) for task in dag.task_dict.values()]
            # One entry per task for which a dependency was detected (None results are dropped).
            serialized_dag["dag_dependencies"] = [
                vars(t)
                for t in (SerializedBaseOperator.detect_dependencies(task) for task in dag.task_dict.values())
                if t is not None
            ]
            serialized_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)

            # Edge info in the JSON exactly matches our internal structure
            serialized_dag["edge_info"] = dag.edge_info
            serialized_dag["params"] = cls._serialize_params_dict(dag.params)

            # has_on_*_callback are only stored if the value is True, as the default is False
            if dag.has_on_success_callback:
                serialized_dag['has_on_success_callback'] = True
            if dag.has_on_failure_callback:
                serialized_dag['has_on_failure_callback'] = True
            return serialized_dag
        except SerializationError:
            raise
        except Exception as e:
            # Chain explicitly so the root cause stays attached to the wrapper.
            raise SerializationError(f'Failed to serialize DAG {dag.dag_id!r}: {e}') from e

    @classmethod
    def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
        """Deserializes a DAG from a JSON object.

        :param encoded_dag: the ``"dag"`` part of a serialized blob; must contain
            ``'_dag_id'`` (a ``KeyError`` is raised otherwise).
        :return: a ``SerializedDAG`` whose attributes are populated from the blob and
            whose tasks and task group are attached to it.
        """
        dag = SerializedDAG(dag_id=encoded_dag['_dag_id'])

        # Copy every entry of the blob onto the DAG, decoding the values that need it.
        for k, v in encoded_dag.items():
            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "tasks":
                # Copy this class's ``_load_operator_extra_links`` setting onto the
                # operator deserializer before any operator is deserialized.
                SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links

                # The task list is stored on the DAG as ``task_dict``, keyed by task_id.
                v = {task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v}
                k = "task_dict"
            elif k == "timezone":
                v = cls._deserialize_timezone(v)
            elif k == "dagrun_timeout":
                v = cls._deserialize_timedelta(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "edge_info":
                # Value structure matches exactly
                pass
            elif k == "timetable":
                v = _decode_timetable(v)
            elif k in cls._decorated_fields:
                v = cls._deserialize(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
            # else use v as it is

            setattr(dag, k, v)

        # A DAG is always serialized with only one of schedule_interval and
        # timetable. This back-populates the other to ensure the two attributes
        # line up correctly on the DAG instance.
        if "timetable" in encoded_dag:
            dag.schedule_interval = dag.timetable.summary
        else:
            dag.timetable = create_timetable(dag.schedule_interval, dag.timezone)

        # Set _task_group

        if "_task_group" in encoded_dag:
            dag._task_group = SerializedTaskGroup.deserialize_task_group(  # type: ignore
                encoded_dag["_task_group"], None, dag.task_dict
            )
        else:
            # This must be old data that had no task_group. Create a root TaskGroup and add
            # all tasks to it.
            dag._task_group = TaskGroup.create_root(dag)
            for task in dag.tasks:
                dag.task_group.add(task)

        # Set has_on_*_callbacks to True if they exist in Serialized blob as False is the default
        if "has_on_success_callback" in encoded_dag:
            dag.has_on_success_callback = True
        if "has_on_failure_callback" in encoded_dag:
            dag.has_on_failure_callback = True

        # Serialized fields that are missing from the blob and are not constructor
        # parameters with a default are explicitly reset to None.
        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
        for k in keys_to_set_none:
            setattr(dag, k, None)

        for task in dag.task_dict.values():
            task.dag = dag
            serializable_task: BaseOperator = task

            # Tasks without their own dates take the DAG's dates.
            for date_attr in ["start_date", "end_date"]:
                if getattr(serializable_task, date_attr) is None:
                    setattr(serializable_task, date_attr, getattr(dag, date_attr))

            if serializable_task.subdag is not None:
                setattr(serializable_task.subdag, 'parent_dag', dag)

            if isinstance(task, MappedOperator):
                # Replace every ``_XcomRef`` placeholder in the mapped/partial kwargs with
                # an XComArg bound to the referenced task of this DAG. Only existing keys
                # are reassigned, so the dict does not change size while being iterated.
                for kwargs in (task.mapped_kwargs, task.partial_kwargs):
                    for arg_name, arg_value in kwargs.items():
                        if not isinstance(arg_value, _XcomRef):
                            continue

                        kwargs[arg_name] = XComArg(
                            operator=dag.get_task(arg_value.task_id), key=arg_value.key
                        )

            # Register this task as upstream of each of its downstream tasks.
            for task_id in serializable_task.downstream_task_ids:
                # Bypass set_upstream etc here - it does more than we want
                dag.task_dict[task_id].upstream_task_ids.add(serializable_task.task_id)

        return dag

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var.

        :param var: the DAG to serialize (it is passed to ``serialize_dag``).
        :return: dict with ``"__version"`` set to ``SERIALIZER_VERSION`` and ``"dag"``
            set to the serialized DAG.
        """
        json_dict = {"__version": cls.SERIALIZER_VERSION, "dag": cls.serialize_dag(var)}

        # Validate Serialized DAG with Json Schema. Raises Error if it mismatches
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG':
        """Deserializes a python dict in to the DAG and operators it contains.

        :param serialized_obj: dict with ``'__version'`` and ``'dag'`` entries.
        :raises ValueError: if ``'__version'`` is absent or differs from
            ``SERIALIZER_VERSION``.
        """
        ver = serialized_obj.get('__version', '<not present>')
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError(f"Unsure how to deserialize version {ver!r}")
        return cls.deserialize_dag(serialized_obj['dag'])