Example #1
def perform_krb181_workaround(principal: str):
    """
    Workaround for Kerberos 1.8.1.

    :param principal: principal name
    :return: the ``kinit -R`` return code (0 on success)
    """
    cmdv: List[str] = [
        conf.get_mandatory_value('kerberos', 'kinit_path'),
        "-c",
        conf.get_mandatory_value('kerberos', 'ccache'),
        "-R",
    ]  # Renew ticket_cache

    log.info("Renewing kerberos ticket to work around kerberos 1.8.1: %s",
             " ".join(cmdv))

    ret = subprocess.call(cmdv, close_fds=True)

    if ret != 0:
        principal = f"{principal or conf.get('kerberos', 'principal')}/{socket.getfqdn()}"
        ccache = conf.get('kerberos', 'ccache')
        log.error(
            "Couldn't renew kerberos ticket in order to work around Kerberos 1.8.1 issue. Please check that "
            "the ticket for '%s' is still renewable:\n  $ kinit -f -c %s\nIf the 'renew until' date is the "
            "same as the 'valid starting' date, the ticket cannot be renewed. Please check your KDC "
            "configuration, and the ticket renewal policy (maxrenewlife) for the '%s' and `krbtgt' "
            "principals.",
            principal,
            ccache,
            principal,
        )
    return ret
Example #2
def send_mime_email(
    e_from: str,
    e_to: Union[str, List[str]],
    mime_msg: MIMEMultipart,
    conn_id: str = "smtp_default",
    dryrun: bool = False,
) -> None:
    """Send MIME email."""
    smtp_host = conf.get_mandatory_value('smtp', 'SMTP_HOST')
    smtp_port = conf.getint('smtp', 'SMTP_PORT')
    smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
    smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
    smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
    smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')
    smtp_user = None
    smtp_password = None

    if conn_id is not None:
        try:
            from airflow.hooks.base import BaseHook

            airflow_conn = BaseHook.get_connection(conn_id)
            smtp_user = airflow_conn.login
            smtp_password = airflow_conn.password
        except AirflowException:
            pass
    if smtp_user is None or smtp_password is None:
        warnings.warn(
            "Fetching SMTP credentials from configuration variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        try:
            smtp_user = conf.get('smtp', 'SMTP_USER')
            smtp_password = conf.get('smtp', 'SMTP_PASSWORD')
        except AirflowConfigException:
            log.debug(
                "No user/password found for SMTP, so logging in with no authentication."
            )

    if not dryrun:
        for attempt in range(1, smtp_retry_limit + 1):
            log.info("Email alerting: attempt %s", str(attempt))
            try:
                smtp_conn = _get_smtp_connection(smtp_host, smtp_port,
                                                 smtp_timeout, smtp_ssl)
            except smtplib.SMTPServerDisconnected:
                if attempt < smtp_retry_limit:
                    continue
                raise

            if smtp_starttls:
                smtp_conn.starttls()
            if smtp_user and smtp_password:
                smtp_conn.login(smtp_user, smtp_password)
            log.info("Sent an alert email to %s", e_to)
            smtp_conn.sendmail(e_from, e_to, mime_msg.as_string())
            smtp_conn.quit()
            break
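A minimal usage sketch (the addresses are hypothetical): build a MIME message and pass it through send_mime_email. With dryrun=True the configuration and credential lookup above still run, but no SMTP connection is opened.

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart()
msg["Subject"] = "Airflow alert"
msg["From"] = "airflow@example.com"
msg["To"] = "oncall@example.com"
msg.attach(MIMEText("Task my_task failed.", "plain"))

# dryrun=True exercises config/credential handling without sending anything.
send_mime_email("airflow@example.com", ["oncall@example.com"], msg, dryrun=True)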
Example #3
def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
    """
    Build airflow url using base_url and default_view and provided query
    For example:
    'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
    """
    import flask

    view = conf.get_mandatory_value('webserver', 'dag_default_view').lower()
    return flask.url_for(f"Airflow.{view}", **query)
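Because flask.url_for needs an active Flask application/request context, a call like the following only works inside the webserver; the dag_id and execution_date are illustrative.

url = build_airflow_url_with_query({
    "dag_id": "my_dag",
    "execution_date": "2020-10-27T10:59:25",
})
# e.g. '/graph?dag_id=my_dag&execution_date=...' when the default view is 'graph'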
Example #4
    def on_kill(self) -> None:
        """Kill Spark submit command"""
        self.log.debug("Kill Command is being called")

        if self._should_track_driver_status:
            if self._driver_id:
                self.log.info('Killing driver %s on cluster', self._driver_id)

                kill_cmd = self._build_spark_driver_kill_command()
                with subprocess.Popen(
                    kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
                ) as driver_kill:
                    self.log.info(
                        "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
                    )

        if self._submit_sp and self._submit_sp.poll() is None:
            self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
            self._submit_sp.kill()

            if self._yarn_application_id:
                kill_cmd = f"yarn application -kill {self._yarn_application_id}".split()
                env = {**os.environ, **(self._env or {})}
                if self._keytab is not None and self._principal is not None:
                    # we are ignoring renewal failures from renew_from_kt
                    # here as the failure could just be due to a non-renewable ticket,
                    # we still attempt to kill the yarn application
                    renew_from_kt(self._principal, self._keytab, exit_on_fail=False)
                    env = os.environ.copy()
                    env["KRB5CCNAME"] = airflow_conf.get_mandatory_value('kerberos', 'ccache')

                with subprocess.Popen(
                    kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
                ) as yarn_kill:
                    self.log.info("YARN app killed with return code: %s", yarn_kill.wait())

            if self._kubernetes_driver_pod:
                self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod)

                # Currently only instantiate Kubernetes client for killing a spark pod.
                try:
                    import kubernetes

                    client = kube_client.get_kube_client()
                    api_response = client.delete_namespaced_pod(
                        self._kubernetes_driver_pod,
                        self._connection['namespace'],
                        body=kubernetes.client.V1DeleteOptions(),
                        pretty=True,
                    )

                    self.log.info("Spark on K8s killed with response: %s", api_response)

                except kube_client.ApiException:
                    self.log.exception("Exception when attempting to kill Spark on K8s")
Example #5
    def get_default_executor(cls) -> "BaseExecutor":
        """Creates a new instance of the configured executor if none exists and returns it"""
        if cls._default_executor is not None:
            return cls._default_executor

        from airflow.configuration import conf

        executor_name = conf.get_mandatory_value('core', 'EXECUTOR')
        cls._default_executor = cls.load_executor(executor_name)

        return cls._default_executor
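Usage sketch, assuming the enclosing class is Airflow's ExecutorLoader: because the instance is cached on the class, repeated calls return the same executor object.

from airflow.executors.executor_loader import ExecutorLoader  # assumed enclosing class

executor = ExecutorLoader.get_default_executor()
assert executor is ExecutorLoader.get_default_executor()  # cached, not re-created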
Example #6
def detect_conf_var() -> bool:
    """Return true if the ticket cache contains "conf" information as is found
    in ticket caches of Kerberos 1.8.1 or later. This is incompatible with the
    Sun Java Krb5LoginModule in Java6, so we need to work around it.
    """
    ticket_cache = conf.get_mandatory_value('kerberos', 'ccache')

    with open(ticket_cache, 'rb') as file:
        # Note: this file is binary, so we check against a bytearray.
        return b'X-CACHECONF:' in file.read()
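One way to exercise detect_conf_var in isolation is to point the ccache option at a temporary file; patching conf here is purely illustrative, not part of the snippet above.

import tempfile
from unittest import mock

with tempfile.NamedTemporaryFile(suffix=".ccache") as fake_cache:
    fake_cache.write(b"X-CACHECONF:")  # marker written by Kerberos >= 1.8.1
    fake_cache.flush()
    with mock.patch.object(conf, "get_mandatory_value", return_value=fake_cache.name):
        assert detect_conf_var() is True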
Example #7
    def _get_multiprocessing_start_method(self) -> str:
        """
        Determine method of creating new processes by checking if the
        mp_start_method is set in configs, else, it uses the OS default.
        """
        if conf.has_option('core', 'mp_start_method'):
            return conf.get_mandatory_value('core', 'mp_start_method')

        method = multiprocessing.get_start_method()
        if not method:
            raise ValueError("Failed to determine start method")
        return method
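The returned name is typically handed to multiprocessing.get_context. A hedged sketch of a caller; processor stands in for an instance of the enclosing class and worker_fn is a placeholder callable.

import multiprocessing

start_method = processor._get_multiprocessing_start_method()
context = multiprocessing.get_context(start_method)
process = context.Process(target=worker_fn)  # worker_fn: any picklable callable
process.start()
process.join()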
Example #8
def get_current_api_client() -> Client:
    """Return current API Client based on current Airflow configuration"""
    api_module = import_module(conf.get_mandatory_value('cli', 'api_client'))  # type: Any
    auth_backends = api.load_auth()
    session = None
    for backend in auth_backends:
        session_factory = getattr(backend, 'create_client_session', None)
        if session_factory:
            session = session_factory()
        api_client = api_module.Client(
            api_base_url=conf.get('cli', 'endpoint_url'),
            auth=getattr(backend, 'CLIENT_AUTH', None),
            session=session,
        )
    return api_client
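Usage sketch from a CLI command's point of view; trigger_dag is part of Airflow's api Client interface, and the dag_id is illustrative.

api_client = get_current_api_client()
api_client.trigger_dag(dag_id="example_dag")  # delegates to the configured backend (local or JSON API)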
Example #9
File: base.py Project: leahecole/airflow
    def __init__(
        self,
        *,
        poke_interval: float = 60,
        timeout: float = conf.getfloat('sensors', 'default_timeout'),
        soft_fail: bool = False,
        mode: str = 'poke',
        exponential_backoff: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.poke_interval = poke_interval
        self.soft_fail = soft_fail
        self.timeout = timeout
        self.mode = mode
        self.exponential_backoff = exponential_backoff
        self._validate_input_values()
        self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor')
        self.sensors_support_sensor_service = set(
            map(lambda l: l.strip(), conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(','))
        )
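Assuming the enclosing class is Airflow's BaseSensorOperator (the snippet comes from sensors' base.py), a concrete sensor only needs to forward these keyword arguments and implement poke(); the subclass name and path below are hypothetical.

import os

class FileExistsSensor(BaseSensorOperator):
    """Waits until a file appears on local disk."""

    def __init__(self, *, path: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.path = path

    def poke(self, context) -> bool:
        # Called every poke_interval seconds until it returns True or timeout is reached.
        return os.path.exists(self.path)

wait = FileExistsSensor(task_id="wait_for_file", path="/tmp/ready", poke_interval=30, timeout=600)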
Example #10
def find_path_from_directory(
    base_dir_path: str,
    ignore_file_name: str,
    ignore_file_syntax: str = conf.get_mandatory_value(
        'core', 'DAG_IGNORE_FILE_SYNTAX', fallback="regexp"),
) -> Generator[str, None, None]:
    """
    Recursively search the base path and return the list of file paths that should not be ignored.
    :param base_dir_path: the base path to be searched
    :param ignore_file_name: the file name in which specifies the patterns of files/dirs to be ignored
    :param ignore_file_syntax: the syntax of patterns in the ignore file: regexp or glob

    :return: a generator of file paths.
    """
    if ignore_file_syntax == "glob":
        return _find_path_from_directory(base_dir_path, ignore_file_name,
                                         _GlobIgnoreRule)
    elif ignore_file_syntax == "regexp" or not ignore_file_syntax:
        return _find_path_from_directory(base_dir_path, ignore_file_name,
                                         _RegexpIgnoreRule)
    else:
        raise ValueError(
            f"Unsupported ignore_file_syntax: {ignore_file_syntax}")
Example #11
    def get_results(
        self,
        ti=None,
        fp=None,
        inline: bool = True,
        delim=None,
        fetch: bool = True,
        include_headers: bool = False,
    ) -> str:
        """
        Get results (or just s3 locations) of a command from Qubole and save them into a file

        :param ti: Task Instance of the dag, used to determine the Qubole command id
        :param fp: Optional file pointer, will create one and return if None passed
        :param inline: True to download actual results, False to get s3 locations only
        :param delim: Replaces the CTL-A chars with the given delim, defaults to ','
        :param fetch: when inline is True, get results directly from s3 (if large)
        :param include_headers: whether to include column headers in the results
        :return: file location containing actual results or s3 locations of results
        """
        if fp is None:
            iso = datetime.datetime.utcnow().isoformat()
            logpath = os.path.expanduser(conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER'))
            resultpath = logpath + '/' + self.dag_id + '/' + self.task_id + '/results'
            pathlib.Path(resultpath).mkdir(parents=True, exist_ok=True)
            fp = open(resultpath + '/' + iso, 'wb')

        if self.cmd is None:
            cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id)
            self.cmd = self.cls.find(cmd_id)

        include_headers_str = 'true' if include_headers else 'false'
        self.cmd.get_results(
            fp, inline, delim, fetch, arguments=[include_headers_str]
        )  # type: ignore[attr-defined]
        fp.flush()
        fp.close()
        return fp.name
Example #12
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.
    from sqlalchemy.orm import Session

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler", "ignore_first_depends_on_past_by_default")
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300))
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core",
             "default_task_weight_rule",
             fallback=WeightRule.DOWNSTREAM))
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: Optional[
    datetime.timedelta] = conf.gettimedelta("core",
Example #13
from sqlalchemy.orm.session import Session as SASession
from sqlalchemy.pool import NullPool

from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf  # NOQA F401
from airflow.executors import executor_constants
from airflow.logging_config import configure_logging
from airflow.utils.orm_event_handlers import setup_event_handlers

if TYPE_CHECKING:
    from airflow.www.utils import UIAlert

log = logging.getLogger(__name__)

TIMEZONE = pendulum.tz.timezone('UTC')
try:
    tz = conf.get_mandatory_value("core", "default_timezone")
    if tz == "system":
        TIMEZONE = pendulum.tz.local_timezone()
    else:
        TIMEZONE = pendulum.tz.timezone(tz)
except Exception:
    pass
log.info("Configured default timezone %s", TIMEZONE)

HEADER = '\n'.join([
    r'  ____________       _____________',
    r' ____    |__( )_________  __/__  /________      __',
    r'____  /| |_  /__  ___/_  /_ __  /_  __ \_ | /| / /',
    r'___  ___ |  / _  /   _  __/ _  / / /_/ /_ |/ |/ /',
    r' _/_/  |_/_/  /_/    /_/    /_/  \____/____/|__/',
])
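One way code can consume the module-level TIMEZONE constant, shown with pendulum for illustration: anything that needs "now" in the configured default timezone can pass it straight through.

import pendulum

now = pendulum.now(TIMEZONE)          # current time in the configured default timezone
midnight = pendulum.today(TIMEZONE)   # start of today in that timezone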
Example #14
File: manager.py Project: dskoda1/airflow
    def __init__(
        self,
        dag_directory: Union[str, "pathlib.Path"],
        max_runs: int,
        processor_timeout: timedelta,
        dag_ids: Optional[List[str]],
        pickle_dags: bool,
        signal_conn: Optional[MultiprocessingConnection] = None,
        async_mode: bool = True,
    ):
        super().__init__()
        self._file_paths: List[str] = []
        self._file_path_queue: List[str] = []
        self._dag_directory = dag_directory
        self._max_runs = max_runs
        # signal_conn is None for dag_processor_standalone mode.
        self._direct_scheduler_conn = signal_conn
        self._pickle_dags = pickle_dags
        self._dag_ids = dag_ids
        self._async_mode = async_mode
        self._parsing_start_time: Optional[int] = None

        # Set the signal conn into non-blocking mode, so that attempting to
        # send when the buffer is full errors out, rather than hanging forever
        # attempting to send (this is to avoid deadlocks!)
        #
        # Don't do this in sync_mode, as we _need_ the DagParsingStat sent to
        # continue the scheduler
        if self._async_mode and self._direct_scheduler_conn is not None:
            os.set_blocking(self._direct_scheduler_conn.fileno(), False)

        self._parallelism = conf.getint('scheduler', 'parsing_processes')
        if (conf.get_mandatory_value('database',
                                     'sql_alchemy_conn').startswith('sqlite')
                and self._parallelism > 1):
            self.log.warning(
                "Because we cannot use more than 1 thread (parsing_processes = %d) "
                "when using sqlite, parallelism is being set to 1.",
                self._parallelism,
            )
            self._parallelism = 1

        # Parse and schedule each file no faster than this interval.
        self._file_process_interval = conf.getint('scheduler',
                                                  'min_file_process_interval')
        # How often to print out DAG file processing stats to the log. Default to
        # 30 seconds.
        self.print_stats_interval = conf.getint('scheduler',
                                                'print_stats_interval')

        # Map from file path to the processor
        self._processors: Dict[str, DagFileProcessorProcess] = {}

        self._num_run = 0

        # Map from file path to stats about the file
        self._file_stats: Dict[str, DagFileStat] = {}

        # Last time that the DAG dir was traversed to look for files
        self.last_dag_dir_refresh_time = timezone.make_aware(
            datetime.fromtimestamp(0))
        # Last time stats were printed
        self.last_stat_print_time = 0
        # Last time we cleaned up DAGs which are no longer in files
        self.last_deactivate_stale_dags_time = timezone.make_aware(
            datetime.fromtimestamp(0))
        # How often to check for DAGs which are no longer in files
        self.deactivate_stale_dags_interval = conf.getint(
            'scheduler', 'deactivate_stale_dags_interval')
        # How long to wait before timing out a process to parse a DAG file
        self._processor_timeout = processor_timeout
        # How often to scan the DAGs directory for new files. Default to 5 minutes.
        self.dag_dir_list_interval = conf.getint('scheduler',
                                                 'dag_dir_list_interval')

        # Mapping file name and callbacks requests
        self._callback_to_execute: Dict[
            str, List[CallbackRequest]] = defaultdict(list)

        self._log = logging.getLogger('airflow.processor_manager')

        self.waitables: Dict[Any, Union[MultiprocessingConnection, DagFileProcessorProcess]] = (
            {self._direct_scheduler_conn: self._direct_scheduler_conn}
            if self._direct_scheduler_conn is not None
            else {}
        )
Example #15
from sqlalchemy import TIMESTAMP, and_, event, false, nullsfirst, or_, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText

from airflow import settings
from airflow.configuration import conf

log = logging.getLogger(__name__)

utc = pendulum.tz.timezone('UTC')

using_mysql = conf.get_mandatory_value(
    'database', 'sql_alchemy_conn').lower().startswith('mysql')


class UtcDateTime(TypeDecorator):
    """
    Almost equivalent to :class:`~sqlalchemy.types.TIMESTAMP` with
    ``timezone=True`` option, but it differs from that by:

    - It never silently accepts a naive :class:`~datetime.datetime`; it raises
      :exc:`ValueError` unless the value is time zone aware.
    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.
    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but a time zone
      aware value, even with SQLite or MySQL.
    - It always returns TIMESTAMP in UTC
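A minimal sketch of the conversion such a decorator performs (not the actual Airflow class body, which lives in airflow.utils.sqlalchemy): reject naive datetimes and normalize aware ones to UTC before they are bound.

class UtcDateTimeSketch(TypeDecorator):
    """Illustrative only; see the real UtcDateTime for the full behaviour."""

    impl = TIMESTAMP(timezone=True)

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        if value.tzinfo is None:
            raise ValueError("naive datetime is disallowed")
        return value.astimezone(utc)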
Example #16
def renew_from_kt(principal: Optional[str],
                  keytab: str,
                  exit_on_fail: bool = True):
    """
    Renew kerberos token from keytab

    :param principal: principal
    :param keytab: keytab file
    :param exit_on_fail: if True, exit the process when kinit fails
    :return: the kinit return code (0 on success)
    """
    # The config is specified in seconds. But we ask for that same amount in
    # minutes to give ourselves a large renewal buffer.
    renewal_lifetime = f"{conf.getint('kerberos', 'reinit_frequency')}m"

    cmd_principal = principal or conf.get_mandatory_value(
        'kerberos', 'principal').replace("_HOST", socket.getfqdn())

    if conf.getboolean('kerberos', 'forwardable'):
        forwardable = '-f'
    else:
        forwardable = '-F'

    if conf.getboolean('kerberos', 'include_ip'):
        include_ip = '-a'
    else:
        include_ip = '-A'

    cmdv: List[str] = [
        conf.get_mandatory_value('kerberos', 'kinit_path'),
        forwardable,
        include_ip,
        "-r",
        renewal_lifetime,
        "-k",  # host ticket
        "-t",
        keytab,  # specify keytab
        "-c",
        conf.get_mandatory_value('kerberos',
                                 'ccache'),  # specify credentials cache
        cmd_principal,
    ]
    log.info("Re-initialising kerberos from keytab: %s",
             " ".join(shlex.quote(f) for f in cmdv))

    with subprocess.Popen(
            cmdv,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
            bufsize=-1,
            universal_newlines=True,
    ) as subp:
        subp.wait()
        if subp.returncode != 0:
            log.error(
                "Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s",
                subp.returncode,
                "\n".join(subp.stdout.readlines() if subp.stdout else []),
                "\n".join(subp.stderr.readlines() if subp.stderr else []),
            )
            if exit_on_fail:
                sys.exit(subp.returncode)
            else:
                return subp.returncode

    global NEED_KRB181_WORKAROUND
    if NEED_KRB181_WORKAROUND is None:
        NEED_KRB181_WORKAROUND = detect_conf_var()
    if NEED_KRB181_WORKAROUND:
        # (From: HUE-640). The Kerberos clock has seconds-level granularity. Make sure we
        # renew the ticket after the initial valid time.
        time.sleep(1.5)
        ret = perform_krb181_workaround(cmd_principal)
        if exit_on_fail and ret != 0:
            sys.exit(ret)
        else:
            return ret
    return 0
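Usage sketch mirroring how a ticket-renewer daemon might drive this function: re-init from the keytab, then sleep for the configured reinit interval and repeat. The keytab path is illustrative.

while True:
    renew_from_kt(principal=None, keytab="/etc/security/airflow.keytab")
    time.sleep(conf.getint('kerberos', 'reinit_frequency'))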
Example #17
def _broker_supports_visibility_timeout(url):
    return url.startswith("redis://") or url.startswith("sqs://")


log = logging.getLogger(__name__)

broker_url = conf.get('celery', 'BROKER_URL')

broker_transport_options = conf.getsection(
    'celery_broker_transport_options') or {}
if 'visibility_timeout' not in broker_transport_options:
    if _broker_supports_visibility_timeout(broker_url):
        broker_transport_options['visibility_timeout'] = 21600

if conf.has_option("celery", 'RESULT_BACKEND'):
    result_backend = conf.get_mandatory_value('celery', 'RESULT_BACKEND')
else:
    log.debug(
        "Value for celery result_backend not found. Using sql_alchemy_conn with db+ prefix."
    )
    result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}'

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'),
    'task_acks_late': True,
    'task_default_queue':
Example #18
"""Airflow logging settings"""

import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse

from airflow.configuration import conf
from airflow.exceptions import AirflowException

# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL: str = conf.get_mandatory_value('logging', 'LOGGING_LEVEL').upper()

# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
FAB_LOG_LEVEL: str = conf.get_mandatory_value('logging',
                                              'FAB_LOGGING_LEVEL').upper()

LOG_FORMAT: str = conf.get_mandatory_value('logging', 'LOG_FORMAT')

LOG_FORMATTER_CLASS: str = conf.get_mandatory_value(
    'logging',
    'LOG_FORMATTER_CLASS',
    fallback='airflow.utils.log.timezone_aware.TimezoneAware')

COLORED_LOG_FORMAT: str = conf.get_mandatory_value('logging',
                                                   'COLORED_LOG_FORMAT')
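These constants are typically assembled into a logging dictConfig; a trimmed-down sketch (not Airflow's full DEFAULT_LOGGING_CONFIG) could look like this.

LOGGING_CONFIG_SKETCH: Dict[str, Any] = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'airflow': {'format': LOG_FORMAT, 'class': LOG_FORMATTER_CLASS},
    },
    'handlers': {
        'console': {'class': 'logging.StreamHandler', 'formatter': 'airflow'},
    },
    'loggers': {
        'flask_appbuilder': {'handlers': ['console'], 'level': FAB_LOG_LEVEL, 'propagate': True},
    },
    'root': {'handlers': ['console'], 'level': LOG_LEVEL},
}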
Example #19
    def get_token(obj: DagModel):
        """Return file token"""
        serializer = URLSafeSerializer(conf.get_mandatory_value('webserver', 'secret_key'))
        return serializer.dumps(obj.fileloc)
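Usage sketch: the token round-trips through the same URLSafeSerializer, so a caller holding the webserver secret_key can recover the DAG file location. Here dag_model stands in for an existing DagModel row.

token = get_token(dag_model)
serializer = URLSafeSerializer(conf.get_mandatory_value('webserver', 'secret_key'))
assert serializer.loads(token) == dag_model.fileloc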