def perform_krb181_workaround(principal: str):
    """
    Workaround for Kerberos 1.8.1.

    :param principal: principal name
    :return: None
    """
    cmdv: List[str] = [
        conf.get_mandatory_value('kerberos', 'kinit_path'),
        "-c",
        conf.get_mandatory_value('kerberos', 'ccache'),
        "-R",
    ]  # Renew ticket_cache

    log.info("Renewing kerberos ticket to work around kerberos 1.8.1: %s", " ".join(cmdv))

    ret = subprocess.call(cmdv, close_fds=True)

    if ret != 0:
        principal = f"{principal or conf.get('kerberos', 'principal')}/{socket.getfqdn()}"
        ccache = conf.get('kerberos', 'ccache')
        log.error(
            "Couldn't renew kerberos ticket in order to work around Kerberos 1.8.1 issue. Please check that "
            "the ticket for '%s' is still renewable:\n $ kinit -f -c %s\nIf the 'renew until' date is the "
            "same as the 'valid starting' date, the ticket cannot be renewed. Please check your KDC "
            "configuration, and the ticket renewal policy (maxrenewlife) for the '%s' and `krbtgt' "
            "principals.",
            principal,
            ccache,
            principal,
        )
    return ret
def send_mime_email(
    e_from: str,
    e_to: Union[str, List[str]],
    mime_msg: MIMEMultipart,
    conn_id: str = "smtp_default",
    dryrun: bool = False,
) -> None:
    """Send MIME email."""
    smtp_host = conf.get_mandatory_value('smtp', 'SMTP_HOST')
    smtp_port = conf.getint('smtp', 'SMTP_PORT')
    smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
    smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
    smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
    smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')
    smtp_user = None
    smtp_password = None

    if conn_id is not None:
        try:
            from airflow.hooks.base import BaseHook

            airflow_conn = BaseHook.get_connection(conn_id)
            smtp_user = airflow_conn.login
            smtp_password = airflow_conn.password
        except AirflowException:
            pass
    if smtp_user is None or smtp_password is None:
        warnings.warn(
            "Fetching SMTP credentials from configuration variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        try:
            smtp_user = conf.get('smtp', 'SMTP_USER')
            smtp_password = conf.get('smtp', 'SMTP_PASSWORD')
        except AirflowConfigException:
            log.debug("No user/password found for SMTP, so logging in with no authentication.")

    if not dryrun:
        for attempt in range(1, smtp_retry_limit + 1):
            log.info("Email alerting: attempt %s", str(attempt))
            try:
                smtp_conn = _get_smtp_connection(smtp_host, smtp_port, smtp_timeout, smtp_ssl)
            except smtplib.SMTPServerDisconnected:
                if attempt < smtp_retry_limit:
                    continue
                raise

            if smtp_starttls:
                smtp_conn.starttls()
            if smtp_user and smtp_password:
                smtp_conn.login(smtp_user, smtp_password)
            log.info("Sent an alert email to %s", e_to)
            smtp_conn.sendmail(e_from, e_to, mime_msg.as_string())
            smtp_conn.quit()
            break
def build_airflow_url_with_query(query: Dict[str, Any]) -> str:
    """
    Build airflow url using base_url and default_view and provided query

    For example:
    'http://0.0.0.0:8000/base/graph?dag_id=my-task&root=&execution_date=2020-10-27T10%3A59%3A25.615587
    """
    import flask

    view = conf.get_mandatory_value('webserver', 'dag_default_view').lower()
    return flask.url_for(f"Airflow.{view}", **query)
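# Hedged usage sketch (not part of the source above): flask.url_for only resolves inside a Flask
# application/request context, so the helper is exercised here through the webserver app. The
# dag_id and query keys below are illustrative assumptions.
from airflow.utils.helpers import build_airflow_url_with_query
from airflow.www.app import create_app

app = create_app(testing=True)
with app.test_request_context():
    # e.g. "/graph?dag_id=example_dag&root=" when dag_default_view is "graph"
    url = build_airflow_url_with_query({"dag_id": "example_dag", "root": ""})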
def on_kill(self) -> None:
    """Kill Spark submit command"""
    self.log.debug("Kill Command is being called")

    if self._should_track_driver_status:
        if self._driver_id:
            self.log.info('Killing driver %s on cluster', self._driver_id)

            kill_cmd = self._build_spark_driver_kill_command()
            with subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as driver_kill:
                self.log.info(
                    "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
                )

    if self._submit_sp and self._submit_sp.poll() is None:
        self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
        self._submit_sp.kill()

        if self._yarn_application_id:
            kill_cmd = f"yarn application -kill {self._yarn_application_id}".split()
            env = {**os.environ, **(self._env or {})}
            if self._keytab is not None and self._principal is not None:
                # we are ignoring renewal failures from renew_from_kt
                # here as the failure could just be due to a non-renewable ticket,
                # we still attempt to kill the yarn application
                renew_from_kt(self._principal, self._keytab, exit_on_fail=False)
                env = os.environ.copy()
                env["KRB5CCNAME"] = airflow_conf.get_mandatory_value('kerberos', 'ccache')

            with subprocess.Popen(
                kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            ) as yarn_kill:
                self.log.info("YARN app killed with return code: %s", yarn_kill.wait())

        if self._kubernetes_driver_pod:
            self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod)

            # Currently only instantiate Kubernetes client for killing a spark pod.
            try:
                import kubernetes

                client = kube_client.get_kube_client()
                api_response = client.delete_namespaced_pod(
                    self._kubernetes_driver_pod,
                    self._connection['namespace'],
                    body=kubernetes.client.V1DeleteOptions(),
                    pretty=True,
                )
                self.log.info("Spark on K8s killed with response: %s", api_response)
            except kube_client.ApiException:
                self.log.exception("Exception when attempting to kill Spark on K8s")
def get_default_executor(cls) -> "BaseExecutor":
    """Creates a new instance of the configured executor if none exists and returns it"""
    if cls._default_executor is not None:
        return cls._default_executor

    from airflow.configuration import conf

    executor_name = conf.get_mandatory_value('core', 'EXECUTOR')
    cls._default_executor = cls.load_executor(executor_name)

    return cls._default_executor
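# Minimal usage sketch (not from the source above; assumes this classmethod lives on
# ExecutorLoader in airflow.executors.executor_loader and a configured Airflow environment):
# the result is cached, so repeated calls return the same executor instance.
from airflow.executors.executor_loader import ExecutorLoader

executor = ExecutorLoader.get_default_executor()
# e.g. a LocalExecutor instance when [core] executor = LocalExecutor
assert executor is ExecutorLoader.get_default_executor()  # cached singleton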
def detect_conf_var() -> bool:
    """Return true if the ticket cache contains "conf" information as is found
    in ticket caches of Kerberos 1.8.1 or later. This is incompatible with the
    Sun Java Krb5LoginModule in Java6, so we need to take an action to work around it.
    """
    ticket_cache = conf.get_mandatory_value('kerberos', 'ccache')

    with open(ticket_cache, 'rb') as file:
        # Note: this file is binary, so we check against a bytearray.
        return b'X-CACHECONF:' in file.read()
def _get_multiprocessing_start_method(self) -> str:
    """
    Determine the method of creating new processes: use ``mp_start_method``
    if it is set in the configs, otherwise fall back to the OS default.
    """
    if conf.has_option('core', 'mp_start_method'):
        return conf.get_mandatory_value('core', 'mp_start_method')

    method = multiprocessing.get_start_method()
    if not method:
        raise ValueError("Failed to determine start method")
    return method
def get_current_api_client() -> Client:
    """Return current API Client based on current Airflow configuration"""
    api_module = import_module(conf.get_mandatory_value('cli', 'api_client'))  # type: Any
    auth_backends = api.load_auth()
    session = None
    for backend in auth_backends:
        session_factory = getattr(backend, 'create_client_session', None)
        if session_factory:
            session = session_factory()
    api_client = api_module.Client(
        api_base_url=conf.get('cli', 'endpoint_url'),
        auth=getattr(backend, 'CLIENT_AUTH', None),
        session=session,
    )
    return api_client
def __init__(
    self,
    *,
    poke_interval: float = 60,
    timeout: float = conf.getfloat('sensors', 'default_timeout'),
    soft_fail: bool = False,
    mode: str = 'poke',
    exponential_backoff: bool = False,
    **kwargs,
) -> None:
    super().__init__(**kwargs)
    self.poke_interval = poke_interval
    self.soft_fail = soft_fail
    self.timeout = timeout
    self.mode = mode
    self.exponential_backoff = exponential_backoff
    self._validate_input_values()
    self.sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor')
    self.sensors_support_sensor_service = set(
        map(lambda l: l.strip(), conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(','))
    )
def find_path_from_directory(
    base_dir_path: str,
    ignore_file_name: str,
    ignore_file_syntax: str = conf.get_mandatory_value('core', 'DAG_IGNORE_FILE_SYNTAX', fallback="regexp"),
) -> Generator[str, None, None]:
    """
    Recursively search the base path and return the list of file paths that should not be ignored.

    :param base_dir_path: the base path to be searched
    :param ignore_file_name: the name of the file that specifies the patterns of files/dirs to be ignored
    :param ignore_file_syntax: the syntax of patterns in the ignore file: regexp or glob
    :return: a generator of file paths.
    """
    if ignore_file_syntax == "glob":
        return _find_path_from_directory(base_dir_path, ignore_file_name, _GlobIgnoreRule)
    elif ignore_file_syntax == "regexp" or not ignore_file_syntax:
        return _find_path_from_directory(base_dir_path, ignore_file_name, _RegexpIgnoreRule)
    else:
        raise ValueError(f"Unsupported ignore_file_syntax: {ignore_file_syntax}")
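# Illustrative sketch (the directory path and ignore-file name are assumptions, and the import
# path assumes this function lives in airflow.utils.file): walk a DAGs folder while honouring
# .airflowignore, using the syntax configured via DAG_IGNORE_FILE_SYNTAX above.
from airflow.utils.file import find_path_from_directory

for file_path in find_path_from_directory("/opt/airflow/dags", ".airflowignore"):
    print(file_path)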
def get_results(
    self,
    ti=None,
    fp=None,
    inline: bool = True,
    delim=None,
    fetch: bool = True,
    include_headers: bool = False,
) -> str:
    """
    Get results (or just s3 locations) of a command from Qubole and save into a file

    :param ti: Task Instance of the dag, used to determine the Qubole command id
    :param fp: Optional file pointer, will create one and return if None passed
    :param inline: True to download actual results, False to get s3 locations only
    :param delim: Replaces the CTL-A chars with the given delim, defaults to ','
    :param fetch: when inline is True, get results directly from s3 (if large)
    :param include_headers: whether to fetch the header row along with the results
    :return: file location containing actual results or s3 locations of results
    """
    if fp is None:
        iso = datetime.datetime.utcnow().isoformat()
        logpath = os.path.expanduser(conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER'))
        resultpath = logpath + '/' + self.dag_id + '/' + self.task_id + '/results'
        pathlib.Path(resultpath).mkdir(parents=True, exist_ok=True)
        fp = open(resultpath + '/' + iso, 'wb')

    if self.cmd is None:
        cmd_id = ti.xcom_pull(key="qbol_cmd_id", task_ids=self.task_id)
        self.cmd = self.cls.find(cmd_id)

    include_headers_str = 'true' if include_headers else 'false'
    self.cmd.get_results(
        fp, inline, delim, fetch, arguments=[include_headers_str]
    )  # type: ignore[attr-defined]
    fp.flush()
    fp.close()
    return fp.name
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.
    from sqlalchemy.orm import Session

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST: bool = conf.getboolean(
    "scheduler", "ignore_first_depends_on_past_by_default"
)
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(
    seconds=conf.getint("core", "default_task_retry_delay", fallback=300)
)
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: Optional[datetime.timedelta] = conf.gettimedelta(
    "core", "default_task_execution_timeout"
)
from sqlalchemy.orm.session import Session as SASession
from sqlalchemy.pool import NullPool

from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf  # NOQA F401
from airflow.executors import executor_constants
from airflow.logging_config import configure_logging
from airflow.utils.orm_event_handlers import setup_event_handlers

if TYPE_CHECKING:
    from airflow.www.utils import UIAlert

log = logging.getLogger(__name__)

TIMEZONE = pendulum.tz.timezone('UTC')
try:
    tz = conf.get_mandatory_value("core", "default_timezone")
    if tz == "system":
        TIMEZONE = pendulum.tz.local_timezone()
    else:
        TIMEZONE = pendulum.tz.timezone(tz)
except Exception:
    pass
log.info("Configured default timezone %s", TIMEZONE)

HEADER = '\n'.join(
    [
        r'  ____________       _____________',
        r' ____    |__( )_________  __/__  /________      __',
        r'____  /| |_  /__  ___/_  /_ __  /_  __ \_ | /| / /',
        r'___  ___ |  / _  /   _  __/ _  / / /_/ /_ |/ |/ /',
        r' _/_/  |_/_/  /_/    /_/    /_/  \____/____/|__/',
    ]
)
def __init__(
    self,
    dag_directory: Union[str, "pathlib.Path"],
    max_runs: int,
    processor_timeout: timedelta,
    dag_ids: Optional[List[str]],
    pickle_dags: bool,
    signal_conn: Optional[MultiprocessingConnection] = None,
    async_mode: bool = True,
):
    super().__init__()
    self._file_paths: List[str] = []
    self._file_path_queue: List[str] = []
    self._dag_directory = dag_directory
    self._max_runs = max_runs
    # signal_conn is None for dag_processor_standalone mode.
    self._direct_scheduler_conn = signal_conn
    self._pickle_dags = pickle_dags
    self._dag_ids = dag_ids
    self._async_mode = async_mode
    self._parsing_start_time: Optional[int] = None

    # Set the signal conn in to non-blocking mode, so that attempting to
    # send when the buffer is full errors, rather than hangs for-ever
    # attempting to send (this is to avoid deadlocks!)
    #
    # Don't do this in sync_mode, as we _need_ the DagParsingStat sent to
    # continue the scheduler
    if self._async_mode and self._direct_scheduler_conn is not None:
        os.set_blocking(self._direct_scheduler_conn.fileno(), False)

    self._parallelism = conf.getint('scheduler', 'parsing_processes')
    if (
        conf.get_mandatory_value('database', 'sql_alchemy_conn').startswith('sqlite')
        and self._parallelism > 1
    ):
        self.log.warning(
            "Because we cannot use more than 1 thread (parsing_processes = "
            "%d) when using sqlite. So we set parallelism to 1.",
            self._parallelism,
        )
        self._parallelism = 1

    # Parse and schedule each file no faster than this interval.
    self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval')

    # How often to print out DAG file processing stats to the log. Default to
    # 30 seconds.
    self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval')

    # Map from file path to the processor
    self._processors: Dict[str, DagFileProcessorProcess] = {}

    self._num_run = 0

    # Map from file path to stats about the file
    self._file_stats: Dict[str, DagFileStat] = {}

    # Last time that the DAG dir was traversed to look for files
    self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0))
    # Last time stats were printed
    self.last_stat_print_time = 0
    # Last time we cleaned up DAGs which are no longer in files
    self.last_deactivate_stale_dags_time = timezone.make_aware(datetime.fromtimestamp(0))
    # How often to check for DAGs which are no longer in files
    self.deactivate_stale_dags_interval = conf.getint('scheduler', 'deactivate_stale_dags_interval')
    # How long to wait before timing out a process to parse a DAG file
    self._processor_timeout = processor_timeout
    # How often to scan the DAGs directory for new files. Default to 5 minutes.
    self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval')

    # Mapping file name and callbacks requests
    self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list)

    self._log = logging.getLogger('airflow.processor_manager')

    self.waitables: Dict[Any, Union[MultiprocessingConnection, DagFileProcessorProcess]] = (
        {
            self._direct_scheduler_conn: self._direct_scheduler_conn,
        }
        if self._direct_scheduler_conn is not None
        else {}
    )
from sqlalchemy import TIMESTAMP, and_, event, false, nullsfirst, or_, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import ColumnElement
from sqlalchemy.sql.expression import ColumnOperators
from sqlalchemy.types import JSON, Text, TypeDecorator, TypeEngine, UnicodeText

from airflow import settings
from airflow.configuration import conf

log = logging.getLogger(__name__)

utc = pendulum.tz.timezone('UTC')

using_mysql = conf.get_mandatory_value('database', 'sql_alchemy_conn').lower().startswith('mysql')


class UtcDateTime(TypeDecorator):
    """
    Almost equivalent to :class:`~sqlalchemy.types.TIMESTAMP` with
    ``timezone=True`` option, but it differs from that by:

    - Never silently takes a naive :class:`~datetime.datetime`; it always
      raises :exc:`ValueError` unless the value is time zone aware.

    - A :class:`~datetime.datetime` value's :attr:`~datetime.datetime.tzinfo`
      is always converted to UTC.

    - Unlike SQLAlchemy's built-in :class:`~sqlalchemy.types.TIMESTAMP`,
      it never returns a naive :class:`~datetime.datetime`, but a time zone
      aware value, even with SQLite or MySQL.

    - Always returns TIMESTAMP in UTC
def renew_from_kt(principal: Optional[str], keytab: str, exit_on_fail: bool = True):
    """
    Renew kerberos token from keytab

    :param principal: principal
    :param keytab: keytab file
    :return: None
    """
    # The config is specified in seconds. But we ask for that same amount in
    # minutes to give ourselves a large renewal buffer.
    renewal_lifetime = f"{conf.getint('kerberos', 'reinit_frequency')}m"

    cmd_principal = principal or conf.get_mandatory_value('kerberos', 'principal').replace(
        "_HOST", socket.getfqdn()
    )

    if conf.getboolean('kerberos', 'forwardable'):
        forwardable = '-f'
    else:
        forwardable = '-F'

    if conf.getboolean('kerberos', 'include_ip'):
        include_ip = '-a'
    else:
        include_ip = '-A'

    cmdv: List[str] = [
        conf.get_mandatory_value('kerberos', 'kinit_path'),
        forwardable,
        include_ip,
        "-r",
        renewal_lifetime,
        "-k",  # host ticket
        "-t",
        keytab,  # specify keytab
        "-c",
        conf.get_mandatory_value('kerberos', 'ccache'),  # specify credentials cache
        cmd_principal,
    ]
    log.info("Re-initialising kerberos from keytab: %s", " ".join(shlex.quote(f) for f in cmdv))

    with subprocess.Popen(
        cmdv,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        close_fds=True,
        bufsize=-1,
        universal_newlines=True,
    ) as subp:
        subp.wait()
        if subp.returncode != 0:
            log.error(
                "Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s",
                subp.returncode,
                "\n".join(subp.stdout.readlines() if subp.stdout else []),
                "\n".join(subp.stderr.readlines() if subp.stderr else []),
            )
            if exit_on_fail:
                sys.exit(subp.returncode)
            else:
                return subp.returncode

    global NEED_KRB181_WORKAROUND
    if NEED_KRB181_WORKAROUND is None:
        NEED_KRB181_WORKAROUND = detect_conf_var()
    if NEED_KRB181_WORKAROUND:
        # (From: HUE-640). The Kerberos clock has seconds-level granularity. Make sure we
        # renew the ticket after the initial valid time.
        time.sleep(1.5)
        ret = perform_krb181_workaround(cmd_principal)
        if exit_on_fail and ret != 0:
            sys.exit(ret)
        else:
            return ret
    return 0
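# For orientation only (all values below are illustrative assumptions, not from the source): with
# principal "airflow", keytab "/etc/security/airflow.keytab", reinit_frequency = 3600, and both
# forwardable and include_ip enabled, the command assembled above resembles:
#
#   kinit -f -a -r 3600m -k -t /etc/security/airflow.keytab -c /tmp/airflow_krb5_ccache airflow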
def _broker_supports_visibility_timeout(url):
    return url.startswith("redis://") or url.startswith("sqs://")


log = logging.getLogger(__name__)

broker_url = conf.get('celery', 'BROKER_URL')

broker_transport_options = conf.getsection('celery_broker_transport_options') or {}
if 'visibility_timeout' not in broker_transport_options:
    if _broker_supports_visibility_timeout(broker_url):
        broker_transport_options['visibility_timeout'] = 21600

if conf.has_option("celery", 'RESULT_BACKEND'):
    result_backend = conf.get_mandatory_value('celery', 'RESULT_BACKEND')
else:
    log.debug("Value for celery result_backend not found. Using sql_alchemy_conn with db+ prefix.")
    result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}'

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'),
    'task_acks_late': True,
    'task_default_queue':
# under the License.
"""Airflow logging settings"""
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import urlparse

from airflow.configuration import conf
from airflow.exceptions import AirflowException

# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL: str = conf.get_mandatory_value('logging', 'LOGGING_LEVEL').upper()


# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
FAB_LOG_LEVEL: str = conf.get_mandatory_value('logging', 'FAB_LOGGING_LEVEL').upper()

LOG_FORMAT: str = conf.get_mandatory_value('logging', 'LOG_FORMAT')
LOG_FORMATTER_CLASS: str = conf.get_mandatory_value(
    'logging', 'LOG_FORMATTER_CLASS', fallback='airflow.utils.log.timezone_aware.TimezoneAware'
)

COLORED_LOG_FORMAT: str = conf.get_mandatory_value('logging', 'COLORED_LOG_FORMAT')
def get_token(obj: DagModel):
    """Return file token"""
    serializer = URLSafeSerializer(conf.get_mandatory_value('webserver', 'secret_key'))
    return serializer.dumps(obj.fileloc)
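# Hedged round-trip sketch (not from the source above; the file location is an illustrative
# assumption): the token is an itsdangerous URL-safe signature of the DAG file location, so the
# same secret key decodes it back to the original path.
from itsdangerous import URLSafeSerializer

serializer = URLSafeSerializer(conf.get_mandatory_value('webserver', 'secret_key'))
token = serializer.dumps("/files/dags/example_dag.py")
assert serializer.loads(token) == "/files/dags/example_dag.py"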