Example #1
    def __init__(
            self,
            task_id,  # type: str
            owner=configuration.conf.get('operators',
                                         'DEFAULT_OWNER'),  # type: str
            email=None,  # type: Optional[str]
            email_on_retry=True,  # type: bool
            email_on_failure=True,  # type: bool
            retries=0,  # type: int
            retry_delay=timedelta(seconds=300),  # type: timedelta
            retry_exponential_backoff=False,  # type: bool
            max_retry_delay=None,  # type: Optional[timedelta]
            start_date=None,  # type: Optional[datetime]
            end_date=None,  # type: Optional[datetime]
            schedule_interval=None,  # not hooked as of now
            depends_on_past=False,  # type: bool
            wait_for_downstream=False,  # type: bool
            dag=None,  # type: Optional[DAG]
            params=None,  # type: Optional[Dict]
            default_args=None,  # type: Optional[Dict]
            priority_weight=1,  # type: int
            weight_rule=WeightRule.DOWNSTREAM,  # type: str
            queue=configuration.conf.get('celery',
                                         'default_queue'),  # type: str
            pool=None,  # type: Optional[str]
            sla=None,  # type: Optional[timedelta]
            execution_timeout=None,  # type: Optional[timedelta]
            on_failure_callback=None,  # type: Optional[Callable]
            on_success_callback=None,  # type: Optional[Callable]
            on_retry_callback=None,  # type: Optional[Callable]
            trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
            resources=None,  # type: Optional[Dict]
            run_as_user=None,  # type: Optional[str]
            task_concurrency=None,  # type: Optional[int]
            executor_config=None,  # type: Optional[Dict]
            do_xcom_push=True,  # type: bool
            inlets=None,  # type: Optional[Dict]
            outlets=None,  # type: Optional[Dict]
            *args,
            **kwargs):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        if schedule_interval:
            self.log.warning(
                "schedule_interval is used for %s, though it has "
                "been deprecated as a task parameter, you need to "
                "specify it as a DAG parameter instead", self)
        self._schedule_interval = schedule_interval
        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback
        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(**(resources or {}))
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: List[DataSet]
        self.outlets = []  # type: List[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)

        self._comps = {
            'task_id',
            'dag_id',
            'owner',
            'email',
            'email_on_retry',
            'retry_delay',
            'retry_exponential_backoff',
            'max_retry_delay',
            'start_date',
            'schedule_interval',
            'depends_on_past',
            'wait_for_downstream',
            'priority_weight',
            'sla',
            'execution_timeout',
            'on_failure_callback',
            'on_success_callback',
            'on_retry_callback',
            'do_xcom_push',
        }
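
This signature matches the 1.10-era BaseOperator and lists every per-task argument it accepted. A minimal usage sketch follows (hedged: it assumes an Airflow 1.10-style install where DummyOperator lives in airflow.operators.dummy_operator; the DAG and task names are illustrative):

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.trigger_rule import TriggerRule

# The DAG supplies start_date and schedule_interval; the operator receives
# the per-task arguments defined by the __init__ signature above.
with DAG(dag_id="example_dag", start_date=datetime(2019, 1, 1),
         schedule_interval="@daily") as dag:
    wait_for_all = DummyOperator(
        task_id="wait_for_all",
        retries=3,
        retry_delay=timedelta(minutes=5),
        trigger_rule=TriggerRule.ALL_SUCCESS,
    )
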
Example #2
if TYPE_CHECKING:
    import jinja2  # Slow import.

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance

DEFAULT_OWNER: str = conf.get("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get("operators", "default_queue")
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core",
             "default_task_weight_rule",
             fallback=WeightRule.DOWNSTREAM))
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS


class AbstractOperator(LoggingMixin, DAGNode):
    """Common implementation for operators, including unmapped and mapped.

    This base class is more about sharing implementations, not defining a common
    interface. Unfortunately it's difficult to use this as the common base class
    for typing due to BaseOperator carrying too much historical baggage.

    The union type ``from airflow.models.operator import Operator`` is easier
    to use for typing purposes.

    :meta private:
    """
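
A hedged sketch of the typing pattern the docstring recommends: annotate shared helpers with the Operator union rather than AbstractOperator (the describe helper below is illustrative, not part of Airflow):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from airflow.models.operator import Operator

def describe(task: "Operator") -> str:
    # Works for unmapped and mapped operators alike.
    return f"{task.task_id} ({type(task).__name__})"
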
Example #3
    def __init__(
            self,
            task_id: str,
            owner: str = conf.get('operators', 'DEFAULT_OWNER'),
            email: Optional[Union[str, Iterable[str]]] = None,
            email_on_retry: bool = True,
            email_on_failure: bool = True,
            retries: Optional[int] = conf.getint('core',
                                                 'default_task_retries',
                                                 fallback=0),
            retry_delay: timedelta = timedelta(seconds=300),
            retry_exponential_backoff: bool = False,
            max_retry_delay: Optional[timedelta] = None,
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            depends_on_past: bool = False,
            wait_for_downstream: bool = False,
            dag=None,
            params: Optional[Dict] = None,
            default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
            priority_weight: int = 1,
            weight_rule: str = WeightRule.DOWNSTREAM,
            queue: str = conf.get('celery', 'default_queue'),
            pool: str = Pool.DEFAULT_POOL_NAME,
            sla: Optional[timedelta] = None,
            execution_timeout: Optional[timedelta] = None,
            on_execute_callback: Optional[Callable] = None,
            on_failure_callback: Optional[Callable] = None,
            on_success_callback: Optional[Callable] = None,
            on_retry_callback: Optional[Callable] = None,
            trigger_rule: str = TriggerRule.ALL_SUCCESS,
            resources: Optional[Dict] = None,
            run_as_user: Optional[str] = None,
            task_concurrency: Optional[int] = None,
            executor_config: Optional[Dict] = None,
            do_xcom_push: bool = True,
            inlets: Optional[Any] = None,
            outlets: Optional[Any] = None,
            *args,
            **kwargs):
        from airflow.models.dag import DagContext
        super().__init__()
        if args or kwargs:
            if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                raise AirflowException(
                    "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                    "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                        c=self.__class__.__name__, a=args, k=kwargs,
                        t=task_id), )
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'future. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_execute_callback = on_execute_callback
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            # noinspection PyTypeChecker
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule
        self.resources: Optional[Resources] = Resources(
            **resources) if resources else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()
        self._dag = None

        self.dag = dag or DagContext.get_current_dag()

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        from airflow.models.dag import DAG
        self.subdag: Optional[DAG] = None

        self._log = logging.getLogger("airflow.task.operators")

        # Lineage
        self.inlets: List = []
        self.outlets: List = []

        self._inlets: List = []
        self._outlets: List = []

        if inlets:
            self._inlets = inlets if isinstance(inlets, list) else [
                inlets,
            ]

        if outlets:
            self._outlets = outlets if isinstance(outlets, list) else [
                outlets,
            ]
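
Unlike Example #1, this version normalizes inlets and outlets to lists, so callers may pass either a single lineage entity or a list. A hedged sketch (SomeOperator and the dataset names are hypothetical):

# Both forms end up as a list in self._inlets:
op_a = SomeOperator(task_id="a", inlets=my_dataset)        # wrapped into [my_dataset]
op_b = SomeOperator(task_id="b", inlets=[ds_one, ds_two])  # kept as-is
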
Example #4
    def __init__(self, **metadata):
        super().__init__(**metadata)
        self.validators = [validate.OneOf(WeightRule.all_weight_rules())] + list(self.validators)
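
This field subclass prepends a OneOf validator so that only known weight rules survive deserialization. A hedged sketch of how such a marshmallow validator behaves on its own (assuming marshmallow is installed; the rule names match what WeightRule.all_weight_rules() returns in Airflow):

from marshmallow import ValidationError, validate

check = validate.OneOf(["downstream", "upstream", "absolute"])
check("downstream")       # returns the value unchanged when it is allowed
try:
    check("sideways")
except ValidationError as err:
    print(err.messages)   # e.g. ['Must be one of: downstream, upstream, absolute.']
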
Example #5
class AirflowConfigParser(ConfigParser):
    """Custom Airflow Configparser supporting defaults and deprecated options"""

    # These configuration elements can be fetched as the stdout of commands
    # following the "{section}__{name}__cmd" pattern; the idea behind this
    # is to avoid storing passwords on boxes in plain-text files.
    # These configs can also be fetched from a secrets backend
    # following the "{section}__{name}__secret" pattern
    sensitive_config_values = {
        ('core', 'sql_alchemy_conn'),
        ('core', 'fernet_key'),
        ('celery', 'broker_url'),
        ('celery', 'flower_basic_auth'),
        ('celery', 'result_backend'),
        ('atlas', 'password'),
        ('smtp', 'smtp_password'),
        ('webserver', 'secret_key'),
    }
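
    # Hedged illustration of the patterns above: a sensitive option can be
    # supplied as a command whose stdout is the value, either in airflow.cfg
    # or via the environment:
    #
    #   [celery]
    #   broker_url_cmd = bash -c 'cat /run/secrets/broker_url'
    #
    #   AIRFLOW__CELERY__BROKER_URL_CMD="cat /run/secrets/broker_url"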

    # A mapping of (new section, new option) -> (old section, old option, since_version).
    # When reading the new option, the old option will be checked to see if it exists.
    # If it does, a DeprecationWarning is issued and the old option's value is used instead
    deprecated_options = {
        ('celery', 'worker_precheck'): ('core', 'worker_precheck', '2.0.0'),
        ('logging', 'base_log_folder'): ('core', 'base_log_folder', '2.0.0'),
        ('logging', 'remote_logging'): ('core', 'remote_logging', '2.0.0'),
        ('logging', 'remote_log_conn_id'): ('core', 'remote_log_conn_id', '2.0.0'),
        ('logging', 'remote_base_log_folder'): ('core', 'remote_base_log_folder', '2.0.0'),
        ('logging', 'encrypt_s3_logs'): ('core', 'encrypt_s3_logs', '2.0.0'),
        ('logging', 'logging_level'): ('core', 'logging_level', '2.0.0'),
        ('logging', 'fab_logging_level'): ('core', 'fab_logging_level', '2.0.0'),
        ('logging', 'logging_config_class'): ('core', 'logging_config_class', '2.0.0'),
        ('logging', 'colored_console_log'): ('core', 'colored_console_log', '2.0.0'),
        ('logging', 'colored_log_format'): ('core', 'colored_log_format', '2.0.0'),
        ('logging', 'colored_formatter_class'): ('core', 'colored_formatter_class', '2.0.0'),
        ('logging', 'log_format'): ('core', 'log_format', '2.0.0'),
        ('logging', 'simple_log_format'): ('core', 'simple_log_format', '2.0.0'),
        ('logging', 'task_log_prefix_template'): ('core', 'task_log_prefix_template', '2.0.0'),
        ('logging', 'log_filename_template'): ('core', 'log_filename_template', '2.0.0'),
        ('logging', 'log_processor_filename_template'): ('core', 'log_processor_filename_template', '2.0.0'),
        ('logging', 'dag_processor_manager_log_location'): (
            'core',
            'dag_processor_manager_log_location',
            '2.0.0',
        ),
        ('logging', 'task_log_reader'): ('core', 'task_log_reader', '2.0.0'),
        ('metrics', 'statsd_on'): ('scheduler', 'statsd_on', '2.0.0'),
        ('metrics', 'statsd_host'): ('scheduler', 'statsd_host', '2.0.0'),
        ('metrics', 'statsd_port'): ('scheduler', 'statsd_port', '2.0.0'),
        ('metrics', 'statsd_prefix'): ('scheduler', 'statsd_prefix', '2.0.0'),
        ('metrics', 'statsd_allow_list'): ('scheduler', 'statsd_allow_list', '2.0.0'),
        ('metrics', 'stat_name_handler'): ('scheduler', 'stat_name_handler', '2.0.0'),
        ('metrics', 'statsd_datadog_enabled'): ('scheduler', 'statsd_datadog_enabled', '2.0.0'),
        ('metrics', 'statsd_datadog_tags'): ('scheduler', 'statsd_datadog_tags', '2.0.0'),
        ('metrics', 'statsd_custom_client_path'): ('scheduler', 'statsd_custom_client_path', '2.0.0'),
        ('scheduler', 'parsing_processes'): ('scheduler', 'max_threads', '1.10.14'),
        ('scheduler', 'scheduler_idle_sleep_time'): ('scheduler', 'processor_poll_interval', '2.2.0'),
        ('operators', 'default_queue'): ('celery', 'default_queue', '2.1.0'),
        ('core', 'hide_sensitive_var_conn_fields'): ('admin', 'hide_sensitive_variable_fields', '2.1.0'),
        ('core', 'sensitive_var_conn_names'): ('admin', 'sensitive_variable_fields', '2.1.0'),
        ('core', 'default_pool_task_slot_count'): ('core', 'non_pooled_task_slot_count', '1.10.4'),
        ('core', 'max_active_tasks_per_dag'): ('core', 'dag_concurrency', '2.2.0'),
        ('logging', 'worker_log_server_port'): ('celery', 'worker_log_server_port', '2.2.0'),
        ('api', 'access_control_allow_origins'): ('api', 'access_control_allow_origin', '2.2.0'),
        ('api', 'auth_backends'): ('api', 'auth_backend', '2.3'),
    }
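
    # Hedged illustration of the fallback: with the mapping above, a config
    # file that still sets [core] base_log_folder will satisfy
    # conf.get('logging', 'base_log_folder'), while emitting a
    # DeprecationWarning asking the user to move the option to [logging].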

    # A mapping of old default values that we want to change and warn the user
    # about. Structured as section -> setting -> (old_pattern, replacement, by_version)
    deprecated_values = {
        'core': {
            'hostname_callable': (re.compile(r':'), r'.', '2.1'),
        },
        'webserver': {
            'navbar_color': (re.compile(r'\A#007A87\Z', re.IGNORECASE), '#fff', '2.1'),
        },
        'email': {
            'email_backend': (
                re.compile(r'^airflow\.contrib\.utils\.sendgrid\.send_email$'),
                r'airflow.providers.sendgrid.utils.emailer.send_email',
                '2.1',
            ),
        },
        'logging': {
            'log_filename_template': (
                re.compile(re.escape("{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log")),
                "XX-set-after-default-config-loaded-XX",
                3.0,
            ),
        },
        'api': {
            'auth_backends': (
                re.compile(r'^airflow\.api\.auth\.backend\.deny_all$|^$'),
                'airflow.api.auth.backend.session',
                '3.0',
            ),
        },
        'elasticsearch': {
            'log_id_template': (
                re.compile('^' + re.escape('{dag_id}-{task_id}-{run_id}-{try_number}') + '$'),
                '{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}',
                3.0,
            )
        },
    }

    _available_logging_levels = ['CRITICAL', 'FATAL', 'ERROR', 'WARN', 'WARNING', 'INFO', 'DEBUG']
    enums_options = {
        ("core", "default_task_weight_rule"): sorted(WeightRule.all_weight_rules()),
        ('core', 'mp_start_method'): multiprocessing.get_all_start_methods(),
        ("scheduler", "file_parsing_sort_mode"): ["modified_time", "random_seeded_by_host", "alphabetical"],
        ("logging", "logging_level"): _available_logging_levels,
        ("logging", "fab_logging_level"): _available_logging_levels,
        # celery_logging_level can be empty, which uses logging_level as fallback
        ("logging", "celery_logging_level"): _available_logging_levels + [''],
    }

    upgraded_values: Dict[Tuple[str, str], str]
    """Mapping of (section,option) to the old value that was upgraded"""

    # This method transforms option names on every read, get, or set operation.
    # It overrides ConfigParser's default behaviour of lowercasing option
    # names, preserving case instead
    def optionxform(self, optionstr: str) -> str:
        return optionstr

    def __init__(self, default_config=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.upgraded_values = {}

        self.airflow_defaults = ConfigParser(*args, **kwargs)
        if default_config is not None:
            self.airflow_defaults.read_string(default_config)
            # Set the upgrade value based on the current loaded default
            default = self.airflow_defaults.get('logging', 'log_filename_template', fallback=None, raw=True)
            if default:
                replacement = self.deprecated_values['logging']['log_filename_template']
                self.deprecated_values['logging']['log_filename_template'] = (
                    replacement[0],
                    default,
                    replacement[2],
                )
            else:
                # In case of tests it might not exist
                with suppress(KeyError):
                    del self.deprecated_values['logging']['log_filename_template']
        else:
            with suppress(KeyError):
                del self.deprecated_values['logging']['log_filename_template']

        self.is_validated = False

    def validate(self):
        self._validate_config_dependencies()
        self._validate_enums()

        for section, replacement in self.deprecated_values.items():
            for name, info in replacement.items():
                old, new, version = info
                current_value = self.get(section, name, fallback="")
                if self._using_old_value(old, current_value):
                    self.upgraded_values[(section, name)] = current_value
                    new_value = old.sub(new, current_value)
                    self._update_env_var(section=section, name=name, new_value=new_value)
                    self._create_future_warning(
                        name=name,
                        section=section,
                        current_value=current_value,
                        new_value=new_value,
                        version=version,
                    )

        self._upgrade_auth_backends()
        self.is_validated = True

    def _upgrade_auth_backends(self):
        """
        Ensure a custom auth_backends setting contains session,
        which is needed by the UI for ajax queries.
        """
        old_value = self.get("api", "auth_backends", fallback="")
        if old_value in ('airflow.api.auth.backend.default', ''):
            # handled by deprecated_values
            pass
        elif old_value.find('airflow.api.auth.backend.session') == -1:
            new_value = old_value + "\nairflow.api.auth.backend.session"
            self._update_env_var(section="api", name="auth_backends", new_value=new_value)
            warnings.warn(
                'The auth_backends setting in [api] has had airflow.api.auth.backend.session added '
                'in the running config, which is needed by the UI. Please update your config before '
                'Apache Airflow 3.0.',
                FutureWarning,
            )

    def _validate_enums(self):
        """Validate that enum type config has an accepted value"""
        for (section_key, option_key), enum_options in self.enums_options.items():
            if self.has_option(section_key, option_key):
                value = self.get(section_key, option_key)
                if value not in enum_options:
                    raise AirflowConfigException(
                        f"`[{section_key}] {option_key}` should not be "
                        + f"{value!r}. Possible values: {', '.join(enum_options)}."
                    )

    def _validate_config_dependencies(self):
        """
        Validate that config values aren't invalid given other config values
        or system-level limitations and requirements.
        """
        is_executor_without_sqlite_support = self.get("core", "executor") not in (
            'DebugExecutor',
            'SequentialExecutor',
        )
        is_sqlite = "sqlite" in self.get('core', 'sql_alchemy_conn')
        if is_sqlite and is_executor_without_sqlite_support:
            raise AirflowConfigException(f"error: cannot use sqlite with the {self.get('core', 'executor')}")
        if is_sqlite:
            import sqlite3

            from airflow.utils.docs import get_docs_url

            # Some of the features in storing rendered fields require sqlite version >= 3.15.0
            min_sqlite_version = (3, 15, 0)
            if _parse_sqlite_version(sqlite3.sqlite_version) < min_sqlite_version:
                min_sqlite_version_str = ".".join(str(s) for s in min_sqlite_version)
                raise AirflowConfigException(
                    f"error: sqlite C library version too old (< {min_sqlite_version_str}). "
                    f"See {get_docs_url('howto/set-up-database.html#setting-up-a-sqlite-database')}"
                )

    def _using_old_value(self, old, current_value):
        return old.search(current_value) is not None

    def _update_env_var(self, section, name, new_value):
        env_var = self._env_var_name(section, name)
        # If the config comes from environment, set it there so that any subprocesses keep the same override!
        if env_var in os.environ:
            os.environ[env_var] = new_value
            return
        if not self.has_section(section):
            self.add_section(section)
        self.set(section, name, new_value)

    @staticmethod
    def _create_future_warning(name, section, current_value, new_value, version):
        warnings.warn(
            f'The {name!r} setting in [{section}] has the old default value of {current_value!r}. '
            f'This value has been changed to {new_value!r} in the running config, but '
            f'please update your config before Apache Airflow {version}.',
            FutureWarning,
        )

    ENV_VAR_PREFIX = 'AIRFLOW__'

    def _env_var_name(self, section: str, key: str) -> str:
        return f'{self.ENV_VAR_PREFIX}{section.upper()}__{key.upper()}'

    def _get_env_var_option(self, section, key):
        # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
        env_var = self._env_var_name(section, key)
        if env_var in os.environ:
            return expand_env_var(os.environ[env_var])
        # alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
        env_var_cmd = env_var + '_CMD'
        if env_var_cmd in os.environ:
            # if this is a valid command key...
            if (section, key) in self.sensitive_config_values:
                return run_command(os.environ[env_var_cmd])
        # alternatively AIRFLOW__{SECTION}__{KEY}_SECRET (to get from Secrets Backend)
        env_var_secret_path = env_var + '_SECRET'
        if env_var_secret_path in os.environ:
            # if this is a valid secret path...
            if (section, key) in self.sensitive_config_values:
                return _get_config_value_from_secret_backend(os.environ[env_var_secret_path])
        return None

    def _get_cmd_option(self, section, key):
        fallback_key = key + '_cmd'
        # if this is a valid command key...
        if (section, key) in self.sensitive_config_values:
            if super().has_option(section, fallback_key):
                command = super().get(section, fallback_key)
                return run_command(command)
        return None

    def _get_secret_option(self, section, key):
        """Get Config option values from Secret Backend"""
        fallback_key = key + '_secret'
        # if this is a valid secret key...
        if (section, key) in self.sensitive_config_values:
            if super().has_option(section, fallback_key):
                secrets_path = super().get(section, fallback_key)
                return _get_config_value_from_secret_backend(secrets_path)
        return None

    def get(self, section, key, **kwargs):
        section = str(section).lower()
        key = str(key).lower()

        deprecated_section, deprecated_key, _ = self.deprecated_options.get(
            (section, key), (None, None, None)
        )

        option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)
        if option is not None:
            return option

        option = self._get_option_from_config_file(deprecated_key, deprecated_section, key, kwargs, section)
        if option is not None:
            return option

        option = self._get_option_from_commands(deprecated_key, deprecated_section, key, section)
        if option is not None:
            return option

        option = self._get_option_from_secrets(deprecated_key, deprecated_section, key, section)
        if option is not None:
            return option

        return self._get_option_from_default_config(section, key, **kwargs)
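
    # Hedged illustration of the lookup order implemented above:
    # environment variable -> config file -> _cmd command -> _secret backend
    # -> default config. For instance, exporting AIRFLOW__CORE__PARALLELISM=64
    # makes conf.get('core', 'parallelism') return '64' regardless of what
    # airflow.cfg contains.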

    def _get_option_from_default_config(self, section, key, **kwargs):
        # ...then the default config
        if self.airflow_defaults.has_option(section, key) or 'fallback' in kwargs:
            return expand_env_var(self.airflow_defaults.get(section, key, **kwargs))

        else:
            log.warning("section/key [%s/%s] not found in config", section, key)

            raise AirflowConfigException(f"section/key [{section}/{key}] not found in config")

    def _get_option_from_secrets(self, deprecated_key, deprecated_section, key, section):
        # ...then from secret backends
        option = self._get_secret_option(section, key)
        if option:
            return option
        if deprecated_section:
            option = self._get_secret_option(deprecated_section, deprecated_key)
            if option:
                self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                return option
        return None

    def _get_option_from_commands(self, deprecated_key, deprecated_section, key, section):
        # ...then commands
        option = self._get_cmd_option(section, key)
        if option:
            return option
        if deprecated_section:
            option = self._get_cmd_option(deprecated_section, deprecated_key)
            if option:
                self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                return option
        return None

    def _get_option_from_config_file(self, deprecated_key, deprecated_section, key, kwargs, section):
        # ...then the config file
        if super().has_option(section, key):
            # Use the parent's methods to get the actual config here to be able to
            # separate the config from default config.
            return expand_env_var(super().get(section, key, **kwargs))
        if deprecated_section:
            if super().has_option(deprecated_section, deprecated_key):
                self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs))
        return None

    def _get_environment_variables(self, deprecated_key, deprecated_section, key, section):
        # first check environment variables
        option = self._get_env_var_option(section, key)
        if option is not None:
            return option
        if deprecated_section:
            option = self._get_env_var_option(deprecated_section, deprecated_key)
            if option is not None:
                self._warn_deprecate(section, key, deprecated_section, deprecated_key)
                return option
        return None

    def getboolean(self, section, key, **kwargs):
        val = str(self.get(section, key, **kwargs)).lower().strip()
        if '#' in val:
            val = val.split('#')[0].strip()
        if val in ('t', 'true', '1'):
            return True
        elif val in ('f', 'false', '0'):
            return False
        else:
            raise AirflowConfigException(
                f'Failed to convert value to bool. Please check "{key}" key in "{section}" section. '
                f'Current value: "{val}".'
            )

    def getint(self, section, key, **kwargs):
        val = self.get(section, key, **kwargs)

        try:
            return int(val)
        except ValueError:
            raise AirflowConfigException(
                f'Failed to convert value to int. Please check "{key}" key in "{section}" section. '
                f'Current value: "{val}".'
            )

    def getfloat(self, section, key, **kwargs):
        val = self.get(section, key, **kwargs)

        try:
            return float(val)
        except ValueError:
            raise AirflowConfigException(
                f'Failed to convert value to float. Please check "{key}" key in "{section}" section. '
                f'Current value: "{val}".'
            )

    def getimport(self, section, key, **kwargs):
        """
        Reads options, imports the full qualified name, and returns the object.

        In case of failure, it throws an exception with the key and section names

        :return: The object or None, if the option is empty
        """
        full_qualified_path = conf.get(section=section, key=key, **kwargs)
        if not full_qualified_path:
            return None

        try:
            return import_string(full_qualified_path)
        except ImportError as e:
            log.error(e)
            raise AirflowConfigException(
                f'The object could not be loaded. Please check "{key}" key in "{section}" section. '
                f'Current value: "{full_qualified_path}".'
            )

    def getjson(self, section, key, fallback=_UNSET, **kwargs) -> Union[dict, list, str, int, float, None]:
        """
        Return a config value parsed from a JSON string.

        ``fallback`` is *not* JSON parsed but used verbatim when no config value is given.
        """
        # self.get() always returns the fallback value as a string, so if
        # someone gives us an object we want to keep it as-is
        default = _UNSET
        if fallback is not _UNSET:
            default = fallback
            fallback = _UNSET

        try:
            data = self.get(section=section, key=key, fallback=fallback, **kwargs)
        except (NoSectionError, NoOptionError):
            return default

        if len(data) == 0:
            return default if default is not _UNSET else None

        try:
            return json.loads(data)
        except JSONDecodeError as e:
            raise AirflowConfigException(f'Unable to parse [{section}] {key!r} as valid json') from e
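
    # Hedged usage sketch: with `extra = {"pool": "ssd"}` under some section
    # in airflow.cfg, conf.getjson('section', 'extra') returns {'pool': 'ssd'};
    # an empty value returns the fallback (or None when no fallback is given).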

    def read(self, filenames, encoding=None):
        super().read(filenames=filenames, encoding=encoding)

    def read_dict(self, dictionary, source='<dict>'):
        super().read_dict(dictionary=dictionary, source=source)

    def has_option(self, section, option):
        try:
            # Using self.get() to avoid reimplementing the priority order
            # of config variables (env, config, cmd, defaults)
            # UNSET to avoid logging a warning about missing values
            self.get(section, option, fallback=_UNSET)
            return True
        except (NoOptionError, NoSectionError):
            return False

    def remove_option(self, section, option, remove_default=True):
        """
        Remove an option if it exists in config from a file or
        default config. If both of config have the same option, this removes
        the option in both configs unless remove_default=False.
        """
        if super().has_option(section, option):
            super().remove_option(section, option)

        if self.airflow_defaults.has_option(section, option) and remove_default:
            self.airflow_defaults.remove_option(section, option)

    def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float, bool]]]:
        """
        Returns the section as a dict. Values are converted to int, float, bool
        as required.

        :param section: section from the config
        :rtype: dict
        """
        if not self.has_section(section) and not self.airflow_defaults.has_section(section):
            return None

        if self.airflow_defaults.has_section(section):
            _section = OrderedDict(self.airflow_defaults.items(section))
        else:
            _section = OrderedDict()

        if self.has_section(section):
            _section.update(OrderedDict(self.items(section)))

        section_prefix = self._env_var_name(section, '')
        for env_var in sorted(os.environ.keys()):
            if env_var.startswith(section_prefix):
                key = env_var.replace(section_prefix, '')
                if key.endswith("_CMD"):
                    key = key[:-4]
                key = key.lower()
                _section[key] = self._get_env_var_option(section, key)

        for key, val in _section.items():
            try:
                val = int(val)
            except ValueError:
                try:
                    val = float(val)
                except ValueError:
                    if val.lower() in ('t', 'true'):
                        val = True
                    elif val.lower() in ('f', 'false'):
                        val = False
            _section[key] = val
        return _section

    def write(self, fp, space_around_delimiters=True):
        # This is based on the configparser.RawConfigParser.write method code to add support for
        # reading options from environment variables.
        if space_around_delimiters:
            delimiter = f" {self._delimiters[0]} "
        else:
            delimiter = self._delimiters[0]
        if self._defaults:
            self._write_section(fp, self.default_section, self._defaults.items(), delimiter)
        for section in self._sections:
            self._write_section(fp, section, self.getsection(section).items(), delimiter)

    def as_dict(
        self,
        display_source=False,
        display_sensitive=False,
        raw=False,
        include_env=True,
        include_cmds=True,
        include_secret=True,
    ) -> Dict[str, Dict[str, str]]:
        """
        Returns the current configuration as an OrderedDict of OrderedDicts.

        :param display_source: If False, the option value is returned. If True,
            a tuple of (option_value, source) is returned. Source is either
            'airflow.cfg', 'default', 'env var', or 'cmd'.
        :param display_sensitive: If True, the values of options set by env
            vars and bash commands will be displayed. If False, those options
            are shown as '< hidden >'
        :param raw: Should the values be output as interpolated values, or the
            "raw" form that can be fed back in to ConfigParser
        :param include_env: Should the value of configuration from AIRFLOW__
            environment variables be included or not
        :param include_cmds: Should the result of calling any *_cmd config be
            set (True, default), or should the _cmd options be left as the
            command to run (False)
        :param include_secret: Should the result of calling any *_secret config be
            set (True, default), or should the _secret options be left as the
            path to get the secret from (False)
        :rtype: Dict[str, Dict[str, str]]
        :return: Dictionary, where the key is the name of the section and the content is
            the dictionary with the name of the parameter and its value.
        """
        config_sources: Dict[str, Dict[str, str]] = {}
        configs = [
            ('default', self.airflow_defaults),
            ('airflow.cfg', self),
        ]

        self._replace_config_with_display_sources(config_sources, configs, display_source, raw)

        # add env vars and overwrite because they have priority
        if include_env:
            self._include_envs(config_sources, display_sensitive, display_source, raw)

        # add bash commands
        if include_cmds:
            self._include_commands(config_sources, display_sensitive, display_source, raw)

        # add config from secret backends
        if include_secret:
            self._include_secrets(config_sources, display_sensitive, display_source, raw)
        return config_sources
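
    # Hedged usage sketch: with display_source=True each value is paired with
    # its origin, e.g. conf.as_dict(display_source=True)['core']['executor']
    # might be ('SequentialExecutor', 'default').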

    def _include_secrets(self, config_sources, display_sensitive, display_source, raw):
        for (section, key) in self.sensitive_config_values:
            opt = self._get_secret_option(section, key)
            if opt:
                if not display_sensitive:
                    opt = '< hidden >'
                if display_source:
                    opt = (opt, 'secret')
                elif raw:
                    opt = opt.replace('%', '%%')
                config_sources.setdefault(section, OrderedDict()).update({key: opt})
                del config_sources[section][key + '_secret']

    def _include_commands(self, config_sources, display_sensitive, display_source, raw):
        for (section, key) in self.sensitive_config_values:
            opt = self._get_cmd_option(section, key)
            if not opt:
                continue
            if not display_sensitive:
                opt = '< hidden >'
            if display_source:
                opt = (opt, 'cmd')
            elif raw:
                opt = opt.replace('%', '%%')
            config_sources.setdefault(section, OrderedDict()).update({key: opt})
            del config_sources[section][key + '_cmd']

    def _include_envs(self, config_sources, display_sensitive, display_source, raw):
        for env_var in [
            os_environment for os_environment in os.environ if os_environment.startswith(self.ENV_VAR_PREFIX)
        ]:
            try:
                _, section, key = env_var.split('__', 2)
                opt = self._get_env_var_option(section, key)
            except ValueError:
                continue
            if opt is None:
                log.warning("Ignoring unknown env var '%s'", env_var)
                continue
            if not display_sensitive and env_var != self._env_var_name('core', 'unit_test_mode'):
                opt = '< hidden >'
            elif raw:
                opt = opt.replace('%', '%%')
            if display_source:
                opt = (opt, 'env var')

            section = section.lower()
            # if we lowercased keys for the kubernetes_environment_variables
            # section, we wouldn't be able to set any Airflow environment
            # variables, because Airflow only parses environment variables
            # that start with AIRFLOW_. Therefore, we make it a special case.
            if section != 'kubernetes_environment_variables':
                key = key.lower()
            config_sources.setdefault(section, OrderedDict()).update({key: opt})

    @staticmethod
    def _replace_config_with_display_sources(config_sources, configs, display_source, raw):
        for (source_name, config) in configs:
            for section in config.sections():
                AirflowConfigParser._replace_section_config_with_display_sources(
                    config, config_sources, display_source, raw, section, source_name
                )

    @staticmethod
    def _replace_section_config_with_display_sources(
        config, config_sources, display_source, raw, section, source_name
    ):
        sect = config_sources.setdefault(section, OrderedDict())
        for (k, val) in config.items(section=section, raw=raw):
            if display_source:
                val = (val, source_name)
            sect[k] = val

    def load_test_config(self):
        """
        Load the unit test configuration.

        Note: this is not reversible.
        """
        # remove all sections, falling back to defaults
        for section in self.sections():
            self.remove_section(section)

        # then read test config

        path = _default_config_file_path('default_test.cfg')
        log.info("Reading default test configuration from %s", path)
        self.read_string(_parameterized_config_from_template('default_test.cfg'))
        # then read any "custom" test settings
        log.info("Reading test configuration from %s", TEST_CONFIG_FILE)
        self.read(TEST_CONFIG_FILE)

    @staticmethod
    def _warn_deprecate(section, key, deprecated_section, deprecated_name):
        if section == deprecated_section:
            warnings.warn(
                f'The {deprecated_name} option in [{section}] has been renamed to {key} - '
                f'the old setting has been used, but please update your config.',
                DeprecationWarning,
                stacklevel=3,
            )
        else:
            warnings.warn(
                f'The {deprecated_name} option in [{deprecated_section}] has been moved to the {key} option '
                f'in [{section}] - the old setting has been used, but please update your config.',
                DeprecationWarning,
                stacklevel=3,
            )

    def __getstate__(self):
        return {
            name: getattr(self, name)
            for name in [
                '_sections',
                'is_validated',
                'airflow_defaults',
            ]
        }

    def __setstate__(self, state):
        self.__init__()
        config = state.pop('_sections')
        self.read_dict(config)
        self.__dict__.update(state)
Example #6
    def test_valid_weight_rules(self):
        self.assertTrue(WeightRule.is_valid(WeightRule.DOWNSTREAM))
        self.assertTrue(WeightRule.is_valid(WeightRule.UPSTREAM))
        self.assertTrue(WeightRule.is_valid(WeightRule.ABSOLUTE))
        self.assertEqual(len(WeightRule.all_weight_rules()), 3)
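
For context, a hedged sketch of what these assertions exercise (in Airflow the WeightRule constants are plain strings and all_weight_rules() returns the three rule names):

from airflow.utils.weight_rule import WeightRule

WeightRule.DOWNSTREAM              # 'downstream'
WeightRule.all_weight_rules()      # the three rule names
WeightRule.is_valid("sideways")    # False
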
Example #7
    def __init__(
            self,
            task_id: str,
            owner: str = conf.get('operators', 'DEFAULT_OWNER'),
            email: Optional[str] = None,
            email_on_retry: bool = True,
            email_on_failure: bool = True,
            retries: Optional[int] = conf.getint('core',
                                                 'default_task_retries',
                                                 fallback=0),
            retry_delay: timedelta = timedelta(seconds=300),
            retry_exponential_backoff: bool = False,
            max_retry_delay: Optional[timedelta] = None,
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            depends_on_past: bool = False,
            wait_for_downstream: bool = False,
            dag: Optional[DAG] = None,
            params: Optional[Dict] = None,
            default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
            priority_weight: int = 1,
            weight_rule: str = WeightRule.DOWNSTREAM,
            queue: str = conf.get('celery', 'default_queue'),
            pool: str = Pool.DEFAULT_POOL_NAME,
            sla: Optional[timedelta] = None,
            execution_timeout: Optional[timedelta] = None,
            on_failure_callback: Optional[Callable] = None,
            on_success_callback: Optional[Callable] = None,
            on_retry_callback: Optional[Callable] = None,
            trigger_rule: str = TriggerRule.ALL_SUCCESS,
            resources: Optional[Dict] = None,
            run_as_user: Optional[str] = None,
            task_concurrency: Optional[int] = None,
            executor_config: Optional[Dict] = None,
            do_xcom_push: bool = True,
            inlets: Optional[Dict] = None,
            outlets: Optional[Dict] = None,
            *args,
            **kwargs):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(
            **resources) if resources is not None else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: List[DataSet]
        self.outlets = []  # type: List[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)
Example #8
    def __init__(
        self,
        task_id,  # type: str
        owner=configuration.conf.get('operators', 'DEFAULT_OWNER'),  # type: str
        email=None,  # type: Optional[str]
        email_on_retry=True,  # type: bool
        email_on_failure=True,  # type: bool
        retries=0,  # type: int
        retry_delay=timedelta(seconds=300),  # type: timedelta
        retry_exponential_backoff=False,  # type: bool
        max_retry_delay=None,  # type: Optional[timedelta]
        start_date=None,  # type: Optional[datetime]
        end_date=None,  # type: Optional[datetime]
        schedule_interval=None,  # not hooked as of now
        depends_on_past=False,  # type: bool
        wait_for_downstream=False,  # type: bool
        dag=None,  # type: Optional[DAG]
        params=None,  # type: Optional[Dict]
        default_args=None,  # type: Optional[Dict]
        priority_weight=1,  # type: int
        weight_rule=WeightRule.DOWNSTREAM,  # type: str
        queue=configuration.conf.get('celery', 'default_queue'),  # type: str
        pool=None,  # type: Optional[str]
        sla=None,  # type: Optional[timedelta]
        execution_timeout=None,  # type: Optional[timedelta]
        on_failure_callback=None,  # type: Optional[Callable]
        on_success_callback=None,  # type: Optional[Callable]
        on_retry_callback=None,  # type: Optional[Callable]
        trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
        resources=None,  # type: Optional[Dict]
        run_as_user=None,  # type: Optional[str]
        task_concurrency=None,  # type: Optional[int]
        executor_config=None,  # type: Optional[Dict]
        do_xcom_push=True,  # type: bool
        inlets=None,  # type: Optional[Dict]
        outlets=None,  # type: Optional[Dict]
        *args,
        **kwargs
    ):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3
            )
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_triggers=TriggerRule.all_triggers(),
                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        if schedule_interval:
            self.log.warning(
                "schedule_interval is used for %s, though it has "
                "been deprecated as a task parameter, you need to "
                "specify it as a DAG parameter instead",
                self
            )
        self._schedule_interval = schedule_interval
        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback
        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_weight_rules=WeightRule.all_weight_rules(),
                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(**(resources or {}))
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: Iterable[DataSet]
        self.outlets = []  # type: Iterable[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)

        self._comps = {
            'task_id',
            'dag_id',
            'owner',
            'email',
            'email_on_retry',
            'retry_delay',
            'retry_exponential_backoff',
            'max_retry_delay',
            'start_date',
            'schedule_interval',
            'depends_on_past',
            'wait_for_downstream',
            'priority_weight',
            'sla',
            'execution_timeout',
            'on_failure_callback',
            'on_success_callback',
            'on_retry_callback',
            'do_xcom_push',
        }
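
The default_args parameter accepted above is normally supplied at the DAG level and merged into every task built inside it. A hedged sketch (import paths as in Airflow 1.10; names are illustrative):

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

default_args = {
    "owner": "data-eng",
    "retries": 2,
    "retry_delay": timedelta(minutes=10),
}

with DAG("etl", start_date=datetime(2019, 1, 1),
         schedule_interval="@daily", default_args=default_args) as dag:
    extract = DummyOperator(task_id="extract")  # inherits retries=2, etc.
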
Example #9
    def test_valid_weight_rules(self):
        self.assertTrue(WeightRule.is_valid(WeightRule.DOWNSTREAM))
        self.assertTrue(WeightRule.is_valid(WeightRule.UPSTREAM))
        self.assertTrue(WeightRule.is_valid(WeightRule.ABSOLUTE))
        self.assertEqual(len(WeightRule.all_weight_rules()), 3)
Example #10
    def test_valid_weight_rules(self):
        assert WeightRule.is_valid(WeightRule.DOWNSTREAM)
        assert WeightRule.is_valid(WeightRule.UPSTREAM)
        assert WeightRule.is_valid(WeightRule.ABSOLUTE)
        assert len(WeightRule.all_weight_rules()) == 3