Example #1
def send_email(to: Union[List[str], Iterable[str]],
               subject: str,
               html_content: str,
               files=None,
               dryrun=False,
               cc=None,
               bcc=None,
               mime_subtype='mixed',
               mime_charset='utf-8',
               **kwargs):
    """
    Send email using backend specified in EMAIL_BACKEND.
    """
    backend = conf.getimport('email', 'EMAIL_BACKEND')
    to_list = get_email_address_list(to)
    to_comma_separated = ", ".join(to_list)

    return backend(to_comma_separated,
                   subject,
                   html_content,
                   files=files,
                   dryrun=dryrun,
                   cc=cc,
                   bcc=bcc,
                   mime_subtype=mime_subtype,
                   mime_charset=mime_charset,
                   **kwargs)
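For context, a hedged sketch of how this function is typically wired up and called. The [email] email_backend key and the SMTP backend path below match stock Airflow defaults, but treat the exact values as assumptions; the addresses are placeholders.

# airflow.cfg (assumption):
#
#   [email]
#   email_backend = airflow.utils.email.send_email_smtp
#
# conf.getimport imports that dotted path and returns the callable, so
# send_email() simply normalizes "to" and forwards everything to the backend.
send_email(
    to=["alice@example.com", "bob@example.com"],
    subject="Daily report",
    html_content="<h3>All tasks succeeded</h3>",
    dryrun=True,  # the stock SMTP backend builds the message but skips sending
)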
Example #2
def get_hostname():
    """
    Fetch the hostname using the callable from the config or using
    `socket.getfqdn` as a fallback.
    """
    # First we attempt to fetch the callable path from the config.
    try:
        callable_path = conf.get('core', 'hostname_callable')
    except AirflowConfigException:
        callable_path = None

    # Then we handle the case when the config is missing or empty. This is the
    # default behavior.
    if not callable_path:
        return socket.getfqdn()

    # Since we have a callable path, we try to import and run it next.
    if ":" in callable_path:
        module_path, attr_name = callable_path.split(':')
        module = importlib.import_module(module_path)
        callable = getattr(module, attr_name)
        return callable()
    else:
        return conf.getimport('core',
                              'hostname_callable',
                              fallback='socket.getfqdn')()
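Two callable-path formats are accepted here and handled by different branches above. A minimal sketch of both, where my_company.hostname_utils and get_static_hostname are hypothetical names:

# Colon form, taken by the explicit importlib branch:
#
#   [core]
#   hostname_callable = my_company.hostname_utils:get_static_hostname
#
# Dotted form, delegated to conf.getimport (with socket.getfqdn as fallback):
#
#   hostname_callable = my_company.hostname_utils.get_static_hostname

def get_static_hostname() -> str:
    # The target callable must take no arguments and return a string.
    return "airflow-worker-01.internal"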
Example #3
def send_email(
    to: Union[List[str], Iterable[str]],
    subject: str,
    html_content: str,
    files: Optional[List[str]] = None,
    dryrun: bool = False,
    cc: Optional[Union[str, Iterable[str]]] = None,
    bcc: Optional[Union[str, Iterable[str]]] = None,
    mime_subtype: str = 'mixed',
    mime_charset: str = 'utf-8',
    conn_id: Optional[str] = None,
    **kwargs,
):
    """Send email using backend specified in EMAIL_BACKEND."""
    backend = conf.getimport('email', 'EMAIL_BACKEND')
    backend_conn_id = conn_id or conf.get("email", "EMAIL_CONN_ID")
    to_list = get_email_address_list(to)
    to_comma_separated = ", ".join(to_list)

    return backend(
        to_comma_separated,
        subject,
        html_content,
        files=files,
        dryrun=dryrun,
        cc=cc,
        bcc=bcc,
        mime_subtype=mime_subtype,
        mime_charset=mime_charset,
        conn_id=backend_conn_id,
        **kwargs,
    )
Example #4
def initialize_secrets_backends() -> List[BaseSecretsBackend]:
    """
    * import secrets backend classes
    * instantiate them and return them in a list
    """
    secrets_backend_cls = conf.getimport(section=CONFIG_SECTION, key='backend')
    backend_list = []

    if secrets_backend_cls:
        try:
            alternative_secrets_config_dict = json.loads(
                conf.get(section=CONFIG_SECTION,
                         key='backend_kwargs',
                         fallback='{}'))
        except JSONDecodeError:
            alternative_secrets_config_dict = {}

        backend_list.append(
            secrets_backend_cls(**alternative_secrets_config_dict))

    for class_name in DEFAULT_SECRETS_SEARCH_PATH:
        secrets_backend_cls = import_string(class_name)
        backend_list.append(secrets_backend_cls())

    return backend_list
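The backend_kwargs handling above silently degrades to an empty dict on malformed JSON. A small self-contained sketch of that behavior, followed by a hypothetical configuration it would consume (the Vault backend path is illustrative):

import json
from json import JSONDecodeError

def parse_backend_kwargs(raw: str) -> dict:
    # Mirrors the try/except above: malformed JSON yields {}.
    try:
        return json.loads(raw)
    except JSONDecodeError:
        return {}

assert parse_backend_kwargs('{"url": "http://vault:8200"}') == {"url": "http://vault:8200"}
assert parse_backend_kwargs('not json') == {}

# airflow.cfg (assumption):
#
#   [secrets]
#   backend = airflow.providers.hashicorp.secrets.vault.VaultBackend
#   backend_kwargs = {"url": "http://vault:8200", "mount_point": "airflow"}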
Example #5
def configure_orm(disable_connection_pool=False):
    """Configure ORM using SQLAlchemy"""
    log.debug("Setting up DB connection pool (PID %s)", os.getpid())
    global engine
    global Session
    engine_args = prepare_engine_args(disable_connection_pool)

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use us.
    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')

    if conf.has_option('core', 'sql_alchemy_connect_args'):
        connect_args = conf.getimport('core', 'sql_alchemy_connect_args')
    else:
        connect_args = {}

    engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args)
    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=engine,
            expire_on_commit=False,
        )
    )
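Note that sql_alchemy_connect_args is loaded with conf.getimport, so the option must name an importable Python object (a dict of DBAPI arguments), not inline JSON. A sketch under that assumption, with hypothetical module and key names:

# my_settings.py, somewhere on PYTHONPATH (hypothetical):
connect_args = {
    "connect_timeout": 30,  # passed straight through to the DBAPI driver
    "sslmode": "require",
}

# airflow.cfg:
#
#   [core]
#   sql_alchemy_connect_args = my_settings.connect_args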
Example #6
    def get_statsd_logger(cls):
        """Returns logger for statsd"""
        # No need to check scheduler/statsd_on here -> this method is only called when it is set;
        # previously it would crash with "'NoneType' object is not callable" when called without it.
        from statsd import StatsClient

        stats_class = conf.getimport('metrics', 'statsd_custom_client_path', fallback=None)

        if stats_class:
            if not issubclass(stats_class, StatsClient):
                raise AirflowConfigException(
                    "Your custom Statsd client must extend the statsd.StatsClient in order to ensure "
                    "backwards compatibility."
                )
            else:
                log.info("Successfully loaded custom Statsd client")

        else:
            stats_class = StatsClient

        statsd = stats_class(
            host=conf.get('metrics', 'statsd_host'),
            port=conf.getint('metrics', 'statsd_port'),
            prefix=conf.get('metrics', 'statsd_prefix'),
        )
        allow_list_validator = AllowListValidator(conf.get('metrics', 'statsd_allow_list', fallback=None))
        return SafeStatsdLogger(statsd, allow_list_validator)
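A minimal sketch of a custom client that would pass the issubclass(stats_class, StatsClient) check above; the class name, module path, and tag prefix are hypothetical:

from statsd import StatsClient

class TaggedStatsClient(StatsClient):
    # Prefix every counter with a team namespace before emitting it.
    def incr(self, stat, count=1, rate=1):
        return super().incr(f"team.data.{stat}", count, rate)

# airflow.cfg (assumption):
#
#   [metrics]
#   statsd_custom_client_path = my_company.metrics.TaggedStatsClient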
Example #7
def send_email(to,
               subject,
               html_content,
               files=None,
               dryrun=False,
               cc=None,
               bcc=None,
               mime_subtype='mixed',
               mime_charset='utf-8',
               **kwargs):
    """
    Send email using backend specified in EMAIL_BACKEND.
    """
    backend = conf.getimport('email', 'EMAIL_BACKEND')
    to = get_email_address_list(to)
    to = ", ".join(to)

    return backend(to,
                   subject,
                   html_content,
                   files=files,
                   dryrun=dryrun,
                   cc=cc,
                   bcc=bcc,
                   mime_subtype=mime_subtype,
                   mime_charset=mime_charset,
                   **kwargs)
Example #8
def get_hostname():
    """
    Fetch the hostname using the callable from the config or using
    `socket.getfqdn` as a fallback.
    """
    return conf.getimport('core',
                          'hostname_callable',
                          fallback='socket.getfqdn')()
Example #9
def resolve_xcom_backend():
    """Resolves custom XCom class"""
    clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
    if clazz:
        if not issubclass(clazz, BaseXCom):
            raise TypeError(
                f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`."
            )
        return clazz
    return BaseXCom
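A minimal sketch of a backend that would satisfy the issubclass check above; the storage strategy hinted at in the comments is purely illustrative:

from airflow.models.xcom import BaseXCom

class CustomXComBackend(BaseXCom):
    @staticmethod
    def serialize_value(value):
        # e.g. push large payloads to object storage and keep only a reference
        return BaseXCom.serialize_value(value)

    @staticmethod
    def deserialize_value(result):
        # symmetric: resolve the stored reference back into the original value
        return BaseXCom.deserialize_value(result)

# airflow.cfg (assumption):
#
#   [core]
#   xcom_backend = my_company.xcom.CustomXComBackend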
Example #10
def resolve_session_factory() -> Type[BaseSessionFactory]:
    """Resolves custom SessionFactory class"""
    clazz = conf.getimport("aws", "session_factory", fallback=None)
    if not clazz:
        return BaseSessionFactory
    if not issubclass(clazz, BaseSessionFactory):
        raise TypeError(
            f"Your custom AWS SessionFactory class `{clazz.__name__}` is not a subclass "
            f"of `{BaseSessionFactory.__name__}`.")
    return clazz
Example #11
def configure_orm(disable_connection_pool=False):
    """Configure ORM using SQLAlchemy"""
    from airflow.utils.log.secrets_masker import mask_secret

    log.debug("Setting up DB connection pool (PID %s)", os.getpid())
    global engine
    global Session
    engine_args = prepare_engine_args(disable_connection_pool)

    if conf.has_option('database', 'sql_alchemy_connect_args'):
        connect_args = conf.getimport('database', 'sql_alchemy_connect_args')
    else:
        connect_args = {}

    engine = create_engine(SQL_ALCHEMY_CONN,
                           connect_args=connect_args,
                           **engine_args)

    mask_secret(engine.url.password)

    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=engine,
            expire_on_commit=False,
        ))
    if engine.dialect.name == 'mssql':
        session = Session()
        try:
            result = session.execute(
                sqlalchemy.text(
                    'SELECT is_read_committed_snapshot_on FROM sys.databases WHERE name=:database_name'
                ),
                params={"database_name": engine.url.database},
            )
            data = result.fetchone()[0]
            if data != 1:
                log.critical(
                    "MSSQL database MUST have READ_COMMITTED_SNAPSHOT enabled."
                )
                log.critical("The database %s has it disabled.",
                             engine.url.database)
                log.critical(
                    "This will cause random deadlocks, Refusing to start.")
                log.critical(
                    "See https://airflow.apache.org/docs/apache-airflow/stable/howto/"
                    "set-up-database.html#setting-up-a-mssql-database")
                raise Exception(
                    "MSSQL database MUST have READ_COMMITTED_SNAPSHOT enabled."
                )
        finally:
            session.close()
Example #12
def resolve_xcom_backend():
    """Resolves custom XCom class"""
    clazz = conf.getimport("core", "xcom_backend", fallback="airflow.models.xcom.{}"
                           .format(BaseXCom.__name__))
    if clazz:
        if not issubclass(clazz, BaseXCom):
            raise TypeError(
                "Your custom XCom class `{class_name}` is not a subclass of `{base_name}`."
                .format(class_name=clazz.__name__, base_name=BaseXCom.__name__)
            )
        return clazz
    return BaseXCom
Example #13
def get_backend() -> Optional[LineageBackend]:
    """Gets the lineage backend if defined in the configs"""
    clazz = conf.getimport("lineage", "backend", fallback=None)

    if clazz:
        if not issubclass(clazz, LineageBackend):
            raise TypeError(
                f"Your custom Lineage class `{clazz.__name__}` "
                f"is not a subclass of `{LineageBackend.__name__}`.")
        else:
            return clazz()

    return None
Example #14
def resolve_xcom_backend() -> Type[BaseXCom]:
    """Resolves custom XCom class"""
    clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
    if not clazz:
        return BaseXCom
    if not issubclass(clazz, BaseXCom):
        raise TypeError(
            f"Your custom XCom class `{clazz.__name__}` is not a subclass of `{BaseXCom.__name__}`."
        )
    base_xcom_params = _get_function_params(BaseXCom.serialize_value)
    xcom_params = _get_function_params(clazz.serialize_value)
    if not set(base_xcom_params) == set(xcom_params):
        _patch_outdated_serializer(clazz=clazz, params=xcom_params)
    return clazz
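_get_function_params is not shown on this page; a plausible implementation based on inspect, sufficient for the parameter-set comparison above (an assumption, not the verbatim helper):

import inspect
from typing import Callable, List

def _get_function_params(function: Callable) -> List[str]:
    # Declared parameter names of the callable, in definition order.
    return list(inspect.signature(function).parameters)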
Example #15
def get_custom_secret_backend() -> Optional[BaseSecretsBackend]:
    """Get Secret Backend if defined in airflow.cfg"""
    secrets_backend_cls = conf.getimport(section='secrets', key='backend')

    if secrets_backend_cls:
        try:
            alternative_secrets_config_dict = json.loads(
                conf.get(section=CONFIG_SECTION,
                         key='backend_kwargs',
                         fallback='{}'))
        except JSONDecodeError:
            alternative_secrets_config_dict = {}

        return secrets_backend_cls(**alternative_secrets_config_dict)
    return None
Example #16
        def __init__(self):
            """Initialize the Sentry SDK."""
            ignore_logger("airflow.task")
            ignore_logger("airflow.jobs.backfill_job.BackfillJob")
            executor_name = conf.get("core", "EXECUTOR")

            sentry_flask = FlaskIntegration()

            # LoggingIntegration is set by default.
            integrations = [sentry_flask]

            if executor_name == "CeleryExecutor":
                from sentry_sdk.integrations.celery import CeleryIntegration

                sentry_celery = CeleryIntegration()
                integrations.append(sentry_celery)

            dsn = None
            sentry_config_opts = conf.getsection("sentry") or {}
            if sentry_config_opts:
                sentry_config_opts.pop("sentry_on")
                old_way_dsn = sentry_config_opts.pop("sentry_dsn", None)
                new_way_dsn = sentry_config_opts.pop("dsn", None)
                # Support backward compatibility with the old-style "sentry_dsn" option.
                dsn = old_way_dsn or new_way_dsn

                unsupported_options = self.UNSUPPORTED_SENTRY_OPTIONS.intersection(
                    sentry_config_opts.keys())
                if unsupported_options:
                    log.warning(
                        "There are unsupported options in [sentry] section: %s",
                        ", ".join(unsupported_options),
                    )

                sentry_config_opts['before_send'] = conf.getimport(
                    'sentry', 'before_send', fallback=None)

            if dsn:
                sentry_sdk.init(dsn=dsn,
                                integrations=integrations,
                                **sentry_config_opts)
            else:
                # Setting up Sentry using environment variables.
                log.debug("Defaulting to SENTRY_DSN in environment.")
                sentry_sdk.init(integrations=integrations,
                                **sentry_config_opts)
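The before_send value loaded via conf.getimport must resolve to a callable with the standard sentry-sdk hook signature (event, hint). A minimal sketch; the logger name and module path are hypothetical:

def scrub_before_send(event, hint):
    # Returning None tells sentry-sdk to drop the event entirely.
    if event.get("logger") == "airflow.processor":
        return None
    return event

# airflow.cfg (assumption):
#
#   [sentry]
#   before_send = my_company.observability.scrub_before_send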
Example #17
    def get_statsd_logger(self):
        if conf.getboolean('scheduler', 'statsd_on'):
            from statsd import StatsClient

            if conf.has_option('scheduler', 'statsd_custom_client_path'):
                stats_class = conf.getimport('scheduler', 'statsd_custom_client_path')

                if not issubclass(stats_class, StatsClient):
                    raise AirflowConfigException(
                        "Your custom Statsd client must extend the statsd.StatsClient in order to ensure "
                        "backwards compatibility."
                    )
                else:
                    log.info("Successfully loaded custom Statsd client")

            else:
                stats_class = StatsClient

            # stats_class is only bound when statsd is enabled, so the client must be
            # constructed inside this branch.
            statsd = stats_class(
                host=conf.get('scheduler', 'statsd_host'),
                port=conf.getint('scheduler', 'statsd_port'),
                prefix=conf.get('scheduler', 'statsd_prefix'))
            allow_list_validator = AllowListValidator(
                conf.get('scheduler', 'statsd_allow_list', fallback=None))
            return SafeStatsdLogger(statsd, allow_list_validator)
Example #18
class SerializedBaseOperator(BaseOperator, BaseSerialization):
    """A JSON serializable representation of operator.

    All operators are casted to SerializedBaseOperator after deserialization.
    Class specific attributes used by UI are move to object attributes.
    """

    _decorated_fields = {'executor_config'}

    _CONSTRUCTOR_PARAMS = {
        k: v.default
        for k, v in signature(BaseOperator.__init__).parameters.items()
        if v.default is not v.empty
    }

    dependency_detector = conf.getimport('scheduler', 'dependency_detector')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # task_type is used by UI to display the correct class type, because UI only
        # receives BaseOperator from deserialized DAGs.
        self._task_type = 'BaseOperator'
        # Move class attributes into object attributes.
        self.ui_color = BaseOperator.ui_color
        self.ui_fgcolor = BaseOperator.ui_fgcolor
        self.template_ext = BaseOperator.template_ext
        self.template_fields = BaseOperator.template_fields
        self.operator_extra_links = BaseOperator.operator_extra_links

    @property
    def task_type(self) -> str:
        # Overwrites task_type of BaseOperator to use _task_type instead of
        # __class__.__name__.
        return self._task_type

    @task_type.setter
    def task_type(self, task_type: str):
        self._task_type = task_type

    @classmethod
    def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
        serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator))

        # Simplify partial_kwargs by comparing it to the most barebone object.
        # Remove all entries that are simply default values.
        serialized_partial = serialized_op["partial_kwargs"]
        for k, default in _get_default_mapped_partial().items():
            try:
                v = serialized_partial[k]
            except KeyError:
                continue
            if v == default:
                del serialized_partial[k]

        # Simplify op_kwargs format. It must be a dict, so we flatten it.
        with contextlib.suppress(KeyError):
            op_kwargs = serialized_op["mapped_kwargs"]["op_kwargs"]
            assert op_kwargs[Encoding.TYPE] == DAT.DICT
            serialized_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
        with contextlib.suppress(KeyError):
            op_kwargs = serialized_op["partial_kwargs"]["op_kwargs"]
            assert op_kwargs[Encoding.TYPE] == DAT.DICT
            serialized_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
        with contextlib.suppress(KeyError):
            op_kwargs = serialized_op["mapped_op_kwargs"]
            assert op_kwargs[Encoding.TYPE] == DAT.DICT
            serialized_op["mapped_op_kwargs"] = op_kwargs[Encoding.VAR]

        serialized_op["_is_mapped"] = True
        return serialized_op

    @classmethod
    def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]:
        return cls._serialize_node(op, include_deps=op.deps is not BaseOperator.deps)

    @classmethod
    def _serialize_node(cls, op: Union[BaseOperator, MappedOperator], include_deps: bool) -> Dict[str, Any]:
        """Serializes operator into a JSON object."""
        serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
        serialize_op['_task_type'] = getattr(op, "_task_type", type(op).__name__)
        serialize_op['_task_module'] = getattr(op, "_task_module", type(op).__module__)

        # Used to determine if an Operator inherits from EmptyOperator
        serialize_op['_is_empty'] = op.inherits_from_empty_operator

        if op.operator_extra_links:
            serialize_op['_operator_extra_links'] = cls._serialize_operator_extra_links(
                op.operator_extra_links
            )

        if include_deps:
            serialize_op['deps'] = cls._serialize_deps(op.deps)

        # Store all template_fields as they are if they are JSON-serializable;
        # if not, store them as strings.
        if op.template_fields:
            for template_field in op.template_fields:
                value = getattr(op, template_field, None)
                if not cls._is_excluded(value, template_field, op):
                    serialize_op[template_field] = serialize_template_field(value)

        if op.params:
            serialize_op['params'] = cls._serialize_params_dict(op.params)

        return serialize_op

    @classmethod
    def _serialize_deps(cls, op_deps: Iterable["BaseTIDep"]) -> List[str]:
        from airflow import plugins_manager

        plugins_manager.initialize_ti_deps_plugins()
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Can not load plugins")

        deps = []
        for dep in op_deps:
            klass = type(dep)
            module_name = klass.__module__
            qualname = f'{module_name}.{klass.__name__}'
            if (
                not qualname.startswith("airflow.ti_deps.deps.")
                and qualname not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qualname} not serialized, please register it through plugins."
                )
            deps.append(qualname)
        # deps needs to be sorted here because op_deps is a set, whose iteration order is
        # unstable, so the same call may otherwise return different results. That would make
        # json.dumps(self.data, sort_keys=True) produce a different dag_hash for identical DAGs.
        return sorted(deps)

    @classmethod
    def populate_operator(cls, op: Operator, encoded_op: Dict[str, Any]) -> None:
        if "label" not in encoded_op:
            # Handle deserialization of old data before the introduction of TaskGroup
            encoded_op["label"] = encoded_op["task_id"]

        # Extra Operator Links defined in Plugins
        op_extra_links_from_plugin = {}

        # We don't want to load Extra Operator links in Scheduler
        if cls._load_operator_extra_links:
            from airflow import plugins_manager

            plugins_manager.initialize_extra_operators_links_plugins()

            if plugins_manager.operator_extra_links is None:
                raise AirflowException("Can not load plugins")

            for ope in plugins_manager.operator_extra_links:
                for operator in ope.operators:
                    if (
                        operator.__name__ == encoded_op["_task_type"]
                        and operator.__module__ == encoded_op["_task_module"]
                    ):
                        op_extra_links_from_plugin.update({ope.name: ope})

            # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
            # set the Operator links attribute
            # The case for "If OperatorLinks are defined in the operator that is being Serialized"
            # is handled in the deserialization loop where it matches k == "_operator_extra_links"
            if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
                setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))

        for k, v in encoded_op.items():
            # TODO: Remove in Airflow 3.0 when dummy operator is removed
            if k == "_is_dummy":
                k = "_is_empty"
            if k == "_downstream_task_ids":
                # Upgrade from old format/name
                k = "downstream_task_ids"
            if k == "label":
                # Label shouldn't be set anymore --  it's computed from task_id now
                continue
            elif k == "downstream_task_ids":
                v = set(v)
            elif k == "subdag":
                v = SerializedDAG.deserialize_dag(v)
            elif k in {"retry_delay", "execution_timeout", "sla", "max_retry_delay"}:
                v = cls._deserialize_timedelta(v)
            elif k in encoded_op["template_fields"]:
                pass
            elif k == "resources":
                v = Resources.from_dict(v)
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "_operator_extra_links":
                if cls._load_operator_extra_links:
                    op_predefined_extra_links = cls._deserialize_operator_extra_links(v)

                    # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
                    op_predefined_extra_links.update(op_extra_links_from_plugin)
                else:
                    op_predefined_extra_links = {}

                v = list(op_predefined_extra_links.values())
                k = "operator_extra_links"

            elif k == "deps":
                v = cls._deserialize_deps(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
            elif k in ("mapped_kwargs", "partial_kwargs"):
                if "op_kwargs" not in v:
                    op_kwargs: Optional[dict] = None
                else:
                    op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
                v = {arg: cls._deserialize(value) for arg, value in v.items()}
                if op_kwargs is not None:
                    v["op_kwargs"] = op_kwargs
            elif k == "mapped_op_kwargs":
                v = {arg: cls._deserialize(value) for arg, value in v.items()}
            elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                v = cls._deserialize(v)
            # else use v as it is

            setattr(op, k, v)

        for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
            # TODO: refactor deserialization of BaseOperator and MappedOperator (split it out);
            # then this check could go away.
            if not hasattr(op, k):
                setattr(op, k, None)

        # Set all the template_field to None that were not present in Serialized JSON
        for field in op.template_fields:
            if not hasattr(op, field):
                setattr(op, field, None)

        # Used to determine if an Operator inherits from EmptyOperator
        setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))

    @classmethod
    def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Operator:
        """Deserializes an operator from a JSON object."""
        op: Operator
        if encoded_op.get("_is_mapped", False):
            # Most of these will be loaded later, these are just some stand-ins.
            op_data = {k: v for k, v in encoded_op.items() if k in BaseOperator.get_serialized_fields()}
            op = MappedOperator(
                operator_class=op_data,
                mapped_kwargs={},
                partial_kwargs={},
                task_id=encoded_op["task_id"],
                params={},
                deps=MappedOperator.deps_for(BaseOperator),
                operator_extra_links=BaseOperator.operator_extra_links,
                template_ext=BaseOperator.template_ext,
                template_fields=BaseOperator.template_fields,
                template_fields_renderers=BaseOperator.template_fields_renderers,
                ui_color=BaseOperator.ui_color,
                ui_fgcolor=BaseOperator.ui_fgcolor,
                is_empty=False,
                task_module=encoded_op["_task_module"],
                task_type=encoded_op["_task_type"],
                dag=None,
                task_group=None,
                start_date=None,
                end_date=None,
                expansion_kwargs_attr=encoded_op["_expansion_kwargs_attr"],
            )
        else:
            op = SerializedBaseOperator(task_id=encoded_op['task_id'])

        cls.populate_operator(op, encoded_op)
        return op

    @classmethod
    def detect_dependencies(cls, op: Operator) -> Optional['DagDependency']:
        """Detects between DAG dependencies for the operator."""
        return cls.dependency_detector.detect_task_dependencies(op)

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: "DAGNode"):
        if var is not None and op.has_dag() and attrname.endswith("_date"):
            # If this date is the same as the matching field in the dag, then
            # don't store it again at the task level.
            dag_date = getattr(op.dag, attrname, None)
            if var is dag_date or var == dag_date:
                return True
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def _deserialize_deps(cls, deps: List[str]) -> Set["BaseTIDep"]:
        from airflow import plugins_manager

        plugins_manager.initialize_ti_deps_plugins()
        if plugins_manager.registered_ti_dep_classes is None:
            raise AirflowException("Can not load plugins")

        instances = set()
        for qualname in set(deps):
            if (
                not qualname.startswith("airflow.ti_deps.deps.")
                and qualname not in plugins_manager.registered_ti_dep_classes
            ):
                raise SerializationError(
                    f"Custom dep class {qualname} not deserialized, please register it through plugins."
                )

            try:
                instances.add(import_string(qualname)())
            except ImportError:
                log.warning("Error importing dep %r", qualname, exc_info=True)
        return instances

    @classmethod
    def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> Dict[str, BaseOperatorLink]:
        """
        Deserialize Operator Links if the classes are registered in Airflow Plugins.
        An error is raised if the OperatorLink is not found in Plugins either.

        :param encoded_op_links: Serialized Operator Link
        :return: De-Serialized Operator Link
        """
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()

        if plugins_manager.registered_operator_link_classes is None:
            raise AirflowException("Can't load plugins")
        op_predefined_extra_links = {}

        for _operator_links_source in encoded_op_links:
            # Get the key, value pair as Tuple where key is OperatorLink ClassName
            # and value is the dictionary containing the arguments passed to the OperatorLink
            #
            # Example of a single iteration:
            #
            #   _operator_links_source =
            #   {
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
            #           'index': 0
            #       }
            #   },
            #
            #   list(_operator_links_source.items()) =
            #   [
            #       (
            #           'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #           {'index': 0}
            #       )
            #   ]
            #
            #   list(_operator_links_source.items())[0] =
            #   (
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #       {
            #           'index': 0
            #       }
            #   )

            _operator_link_class_path, data = list(_operator_links_source.items())[0]
            if _operator_link_class_path in get_operator_extra_links():
                single_op_link_class = import_string(_operator_link_class_path)
            elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
                single_op_link_class = plugins_manager.registered_operator_link_classes[
                    _operator_link_class_path
                ]
            else:
                log.error("Operator Link class %r not registered", _operator_link_class_path)
                return {}

            op_predefined_extra_link: BaseOperatorLink = cattr.structure(data, single_op_link_class)

            op_predefined_extra_links.update({op_predefined_extra_link.name: op_predefined_extra_link})

        return op_predefined_extra_links

    @classmethod
    def _serialize_operator_extra_links(cls, operator_extra_links: Iterable[BaseOperatorLink]):
        """
        Serialize Operator Links. Store the import path of the OperatorLink and the arguments
        passed to it. Example
        ``[{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]``

        :param operator_extra_links: Operator Link
        :return: Serialized Operator Link
        """
        serialize_operator_extra_links = []
        for operator_extra_link in operator_extra_links:
            op_link_arguments = cattr.unstructure(operator_extra_link)
            if not isinstance(op_link_arguments, dict):
                op_link_arguments = {}

            module_path = (
                f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
            )
            serialize_operator_extra_links.append({module_path: op_link_arguments})

        return serialize_operator_extra_links
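The class-level dependency_detector is itself resolved with conf.getimport, so a deployment can swap in its own detector; detect_dependencies above only requires a detect_task_dependencies(op) entry point. A hedged sketch, assuming the stock detector lives at airflow.serialization.serialized_objects.DependencyDetector:

from airflow.serialization.serialized_objects import DependencyDetector

class CustomDependencyDetector(DependencyDetector):
    @staticmethod
    def detect_task_dependencies(task):
        # Inspect the task and return a DagDependency (or None); here we
        # simply defer to the stock behaviour.
        return DependencyDetector.detect_task_dependencies(task)

# airflow.cfg (assumption):
#
#   [scheduler]
#   dependency_detector = my_company.serialization.CustomDependencyDetector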
Example #19
def configure_orm(disable_connection_pool=False):
    """ Configure ORM using SQLAlchemy"""
    log.debug("Setting up DB connection pool (PID %s)", os.getpid())
    global engine
    global Session
    engine_args = {}

    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.configure_orm(): Using NullPool")
    elif 'sqlite' not in SQL_ALCHEMY_CONN:
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of “sleeping” connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('core',
                                   'SQL_ALCHEMY_MAX_OVERFLOW',
                                   fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('core',
                                   'SQL_ALCHEMY_POOL_RECYCLE',
                                   fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like “SELECT 1”, but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('core',
                                        'SQL_ALCHEMY_POOL_PRE_PING',
                                        fallback=True)

        log.debug(
            "settings.configure_orm(): Using pool settings. pool_size=%d, max_overflow=%d, "
            "pool_recycle=%d, pid=%d", pool_size, max_overflow, pool_recycle,
            os.getpid())
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use us.
    engine_args['encoding'] = conf.get('core',
                                       'SQL_ENGINE_ENCODING',
                                       fallback='utf-8')

    if conf.has_option('core', 'sql_alchemy_connect_args'):
        connect_args = conf.getimport('core', 'sql_alchemy_connect_args')
    else:
        connect_args = {}

    engine = create_engine(SQL_ALCHEMY_CONN,
                           connect_args=connect_args,
                           **engine_args)
    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(autocommit=False,
                     autoflush=False,
                     bind=engine,
                     expire_on_commit=False))
Example #20
class SerializedBaseOperator(BaseOperator, BaseSerialization):
    """A JSON serializable representation of operator.

    All operators are casted to SerializedBaseOperator after deserialization.
    Class specific attributes used by UI are move to object attributes.
    """

    _decorated_fields = {'executor_config'}

    _CONSTRUCTOR_PARAMS = {
        k: v.default
        for k, v in signature(BaseOperator.__init__).parameters.items()
        if v.default is not v.empty
    }

    dependency_detector = conf.getimport('scheduler', 'dependency_detector')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # task_type is used by UI to display the correct class type, because UI only
        # receives BaseOperator from deserialized DAGs.
        self._task_type = 'BaseOperator'
        # Move class attributes into object attributes.
        self.ui_color = BaseOperator.ui_color
        self.ui_fgcolor = BaseOperator.ui_fgcolor
        self.template_ext = BaseOperator.template_ext
        self.template_fields = BaseOperator.template_fields
        self.operator_extra_links = BaseOperator.operator_extra_links

    @property
    def task_type(self) -> str:
        # Overwrites task_type of BaseOperator to use _task_type instead of
        # __class__.__name__.
        return self._task_type

    @task_type.setter
    def task_type(self, task_type: str):
        self._task_type = task_type

    @classmethod
    def serialize_operator(cls, op: BaseOperator) -> Dict[str, Any]:
        """Serializes operator into a JSON object."""
        serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
        serialize_op['_task_type'] = op.__class__.__name__
        serialize_op['_task_module'] = op.__class__.__module__

        # Used to determine if an Operator inherits from DummyOperator
        serialize_op['_is_dummy'] = op.inherits_from_dummy_operator

        if op.operator_extra_links:
            serialize_op[
                '_operator_extra_links'] = cls._serialize_operator_extra_links(
                    op.operator_extra_links)

        if op.deps is not BaseOperator.deps:
            # Are the deps different to BaseOperator, if so serialize the class names!
            # For Airflow 2.0 expediency we _only_ allow built in Dep classes.
            # Fix this for 2.0.x or 2.1
            deps = []
            for dep in op.deps:
                klass = type(dep)
                module_name = klass.__module__
                if not module_name.startswith("airflow.ti_deps.deps."):
                    raise SerializationError(
                        f"Cannot serialize {(op.dag.dag_id + '.' + op.task_id)!r} with `deps` from non-core "
                        f"module {module_name!r}")

                deps.append(f'{module_name}.{klass.__name__}')
            # deps needs to be sorted here because op.deps is a set, whose iteration order is
            # unstable, so the same call may otherwise return different results. That would make
            # json.dumps(self.data, sort_keys=True) produce a different dag_hash for identical DAGs.
            serialize_op['deps'] = sorted(deps)

        # Store all template_fields as they are if they are JSON-serializable;
        # if not, store them as strings.
        if op.template_fields:
            for template_field in op.template_fields:
                value = getattr(op, template_field, None)
                if not cls._is_excluded(value, template_field, op):
                    serialize_op[template_field] = serialize_template_field(
                        value)

        if op.params:
            serialize_op['params'] = cls._serialize_params_dict(op.params)

        return serialize_op

    @classmethod
    def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
        """Deserializes an operator from a JSON object."""
        op = SerializedBaseOperator(task_id=encoded_op['task_id'])

        if "label" not in encoded_op:
            # Handle deserialization of old data before the introduction of TaskGroup
            encoded_op["label"] = encoded_op["task_id"]

        # Extra Operator Links defined in Plugins
        op_extra_links_from_plugin = {}

        # We don't want to load Extra Operator links in Scheduler
        if cls._load_operator_extra_links:
            from airflow import plugins_manager

            plugins_manager.initialize_extra_operators_links_plugins()

            if plugins_manager.operator_extra_links is None:
                raise AirflowException("Can not load plugins")

            for ope in plugins_manager.operator_extra_links:
                for operator in ope.operators:
                    if (operator.__name__ == encoded_op["_task_type"] and
                            operator.__module__ == encoded_op["_task_module"]):
                        op_extra_links_from_plugin.update({ope.name: ope})

            # If OperatorLinks are defined in Plugins but not in the Operator that is being Serialized
            # set the Operator links attribute
            # The case for "If OperatorLinks are defined in the operator that is being Serialized"
            # is handled in the deserialization loop where it matches k == "_operator_extra_links"
            if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
                setattr(op, "operator_extra_links",
                        list(op_extra_links_from_plugin.values()))

        for k, v in encoded_op.items():

            if k == "_downstream_task_ids":
                v = set(v)
            elif k == "subdag":
                v = SerializedDAG.deserialize_dag(v)
            elif k in {
                    "retry_delay", "execution_timeout", "sla",
                    "max_retry_delay"
            }:
                v = cls._deserialize_timedelta(v)
            elif k in encoded_op["template_fields"]:
                pass
            elif k.endswith("_date"):
                v = cls._deserialize_datetime(v)
            elif k == "_operator_extra_links":
                if cls._load_operator_extra_links:
                    op_predefined_extra_links = cls._deserialize_operator_extra_links(
                        v)

                    # If OperatorLinks with the same name exists, Links via Plugin have higher precedence
                    op_predefined_extra_links.update(
                        op_extra_links_from_plugin)
                else:
                    op_predefined_extra_links = {}

                v = list(op_predefined_extra_links.values())
                k = "operator_extra_links"

            elif k == "deps":
                v = cls._deserialize_deps(v)
            elif k == "params":
                v = cls._deserialize_params_dict(v)
            elif k in cls._decorated_fields or k not in op.get_serialized_fields(
            ):
                v = cls._deserialize(v)
            # else use v as it is

            setattr(op, k, v)

        for k in op.get_serialized_fields() - encoded_op.keys(
        ) - cls._CONSTRUCTOR_PARAMS.keys():
            setattr(op, k, None)

        # Set all the template_field to None that were not present in Serialized JSON
        for field in op.template_fields:
            if not hasattr(op, field):
                setattr(op, field, None)

        # Used to determine if an Operator inherits from DummyOperator
        setattr(op, "_is_dummy", bool(encoded_op.get("_is_dummy", False)))

        return op

    @classmethod
    def detect_dependencies(cls,
                            op: BaseOperator) -> Optional['DagDependency']:
        """Detects between DAG dependencies for the operator."""
        return cls.dependency_detector.detect_task_dependencies(op)

    @classmethod
    def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator):
        if var is not None and op.has_dag() and attrname.endswith("_date"):
            # If this date is the same as the matching field in the dag, then
            # don't store it again at the task level.
            dag_date = getattr(op.dag, attrname, None)
            if var is dag_date or var == dag_date:
                return True
        return super()._is_excluded(var, attrname, op)

    @classmethod
    def _deserialize_deps(cls, deps: List[str]) -> Set["BaseTIDep"]:
        instances = set()
        for qualname in set(deps):
            if not qualname.startswith("airflow.ti_deps.deps."):
                log.error("Dep class %r not registered", qualname)
                continue

            try:
                instances.add(import_string(qualname)())
            except ImportError:
                log.warning("Error importing dep %r", qualname, exc_info=True)
        return instances

    @classmethod
    def _deserialize_operator_extra_links(
            cls, encoded_op_links: list) -> Dict[str, BaseOperatorLink]:
        """
        Deserialize Operator Links if the classes are registered in Airflow Plugins.
        An error is raised if the OperatorLink is not found in Plugins either.

        :param encoded_op_links: Serialized Operator Link
        :return: De-Serialized Operator Link
        """
        from airflow import plugins_manager

        plugins_manager.initialize_extra_operators_links_plugins()

        if plugins_manager.registered_operator_link_classes is None:
            raise AirflowException("Can't load plugins")
        op_predefined_extra_links = {}

        for _operator_links_source in encoded_op_links:
            # Get the key, value pair as Tuple where key is OperatorLink ClassName
            # and value is the dictionary containing the arguments passed to the OperatorLink
            #
            # Example of a single iteration:
            #
            #   _operator_links_source =
            #   {
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
            #           'index': 0
            #       }
            #   },
            #
            #   list(_operator_links_source.items()) =
            #   [
            #       (
            #           'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #           {'index': 0}
            #       )
            #   ]
            #
            #   list(_operator_links_source.items())[0] =
            #   (
            #       'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
            #       {
            #           'index': 0
            #       }
            #   )

            _operator_link_class_path, data = list(
                _operator_links_source.items())[0]
            if _operator_link_class_path in get_operator_extra_links():
                single_op_link_class = import_string(_operator_link_class_path)
            elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
                single_op_link_class = plugins_manager.registered_operator_link_classes[
                    _operator_link_class_path]
            else:
                log.error("Operator Link class %r not registered",
                          _operator_link_class_path)
                return {}

            op_predefined_extra_link: BaseOperatorLink = cattr.structure(
                data, single_op_link_class)

            op_predefined_extra_links.update(
                {op_predefined_extra_link.name: op_predefined_extra_link})

        return op_predefined_extra_links

    @classmethod
    def _serialize_operator_extra_links(
            cls, operator_extra_links: Iterable[BaseOperatorLink]):
        """
        Serialize Operator Links. Store the import path of the OperatorLink and the arguments
        passed to it. Example
        ``[{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]``

        :param operator_extra_links: Operator Link
        :return: Serialized Operator Link
        """
        serialize_operator_extra_links = []
        for operator_extra_link in operator_extra_links:
            op_link_arguments = cattr.unstructure(operator_extra_link)
            if not isinstance(op_link_arguments, dict):
                op_link_arguments = {}

            module_path = (
                f"{operator_extra_link.__class__.__module__}.{operator_extra_link.__class__.__name__}"
            )
            serialize_operator_extra_links.append(
                {module_path: op_link_arguments})

        return serialize_operator_extra_links
Example #21
def get_current_handler_stat_name_func() -> Callable[[str], str]:
    """Get Stat Name Handler from airflow.cfg"""
    return conf.getimport('metrics', 'stat_name_handler') or stat_name_default_handler
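The configured handler receives each proposed stat name and returns the name that is actually emitted. A minimal sketch; the naming policy and module path are hypothetical:

def my_stat_name_handler(stat_name: str) -> str:
    # Flatten dots so every metric lands in a single namespace.
    return f"airflow.{stat_name.replace('.', '_')}"

# airflow.cfg (assumption):
#
#   [metrics]
#   stat_name_handler = my_company.metrics.my_stat_name_handler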
Example #22
log = logging.getLogger(__name__)

# Make it constant for unit test.
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'

OPERATION_TIMEOUT = conf.getint('celery', 'operation_timeout', fallback=2)
'''
To start the celery worker, run the command:
airflow celery worker
'''

if conf.has_option('celery', 'celery_config_options'):
    celery_configuration = conf.getimport('celery', 'celery_config_options')
else:
    celery_configuration = DEFAULT_CELERY_CONFIG

app = Celery(conf.get('celery', 'CELERY_APP_NAME'),
             config_source=celery_configuration)


@app.task
def execute_command(command_to_exec: CommandType) -> None:
    """Executes command."""
    BaseExecutor.validate_command(command_to_exec)
    log.info("Executing command in Celery: %s", command_to_exec)

    if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
        _execute_in_subprocess(command_to_exec)
    else:
        _execute_in_fork(command_to_exec)
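celery_config_options must likewise name an importable configuration object. The usual pattern, sketched here with a hypothetical module path, is to copy the shipped defaults and override selectively:

# my_celery_config.py (hypothetical module on PYTHONPATH):
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG

CELERY_CONFIG = {
    **DEFAULT_CELERY_CONFIG,
    "worker_max_tasks_per_child": 100,  # recycle workers to cap memory growth
}

# airflow.cfg:
#
#   [celery]
#   celery_config_options = my_celery_config.CELERY_CONFIG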
Example #23
def get_current_handler_stat_name_func() -> Callable[[str], str]:
    return conf.getimport('scheduler', 'stat_name_handler') or stat_name_default_handler