Example #1
def backfill(args):
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException("dag_id could not be found")
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, use it for both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(task_regex=args.task_regex, include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or conf.getboolean("core", "donot_pickle")),
            ignore_dependencies=args.ignore_dependencies,
        )
Example #2
def get_kube_client(in_cluster=conf.getboolean('kubernetes', 'in_cluster'),
                    cluster_context=None,
                    config_file=None):
    if not in_cluster:
        if cluster_context is None:
            cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None)
        if config_file is None:
            config_file = conf.get('kubernetes', 'config_file', fallback=None)
    return _load_kube_config(in_cluster, cluster_context, config_file)
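
A minimal sketch of what the delegated loader could look like, assuming the official kubernetes Python client; the body below is an illustration, not the Airflow implementation of _load_kube_config.

from kubernetes import client, config

def _load_kube_config(in_cluster, cluster_context, config_file):
    # Illustrative only: inside a cluster, use the mounted service-account token;
    # otherwise fall back to a kubeconfig file (defaults to ~/.kube/config).
    if in_cluster:
        config.load_incluster_config()
    else:
        config.load_kube_config(config_file=config_file, context=cluster_context)
    return client.CoreV1Api()
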
Example #3
def configure_orm(disable_connection_pool=False):
    log.debug("Setting up DB connection pool (PID %s)" % os.getpid())
    global engine
    global Session
    engine_args = {}

    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.configure_orm(): Using NullPool")
    elif 'sqlite' not in SQL_ALCHEMY_CONN:
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        try:
            pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE')
        except conf.AirflowConfigException:
            pool_size = 5

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        try:
            pool_recycle = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE')
        except conf.AirflowConfigException:
            pool_recycle = 1800

        log.info("settings.configure_orm(): Using pool settings. pool_size={}, "
                 "pool_recycle={}, pid={}".format(pool_size, pool_recycle, os.getpid()))
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use
    # us.
    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')
    # For Python2 we get back a newstr and need a str
    engine_args['encoding'] = engine_args['encoding'].__str__()

    engine = create_engine(SQL_ALCHEMY_CONN, **engine_args)
    reconnect_timeout = conf.getint('core', 'SQL_ALCHEMY_RECONNECT_TIMEOUT')
    setup_event_handlers(engine, reconnect_timeout)

    Session = scoped_session(
        sessionmaker(autocommit=False,
                     autoflush=False,
                     bind=engine,
                     expire_on_commit=False))
Example #4
def validate_session():
    worker_precheck = conf.getboolean('core', 'worker_precheck', fallback=False)
    if not worker_precheck:
        return True
    else:
        check_session = sessionmaker(bind=engine)
        session = check_session()
        try:
            session.execute("select 1")
            conn_status = True
        except exc.DBAPIError as err:
            log.error(err)
            conn_status = False
        session.close()
        return conn_status
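
A hypothetical caller guarding worker start-up on the precheck above (only meaningful when worker_precheck is enabled in the [core] section):

import logging
import sys

log = logging.getLogger(__name__)

# Refuse to start a worker if the metadata database is unreachable.
if not validate_session():
    log.error("Worker exiting, database connection precheck failed.")
    sys.exit(1)
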
Example #5
def send_MIME_email(e_from, e_to, mime_msg):
    SMTP_HOST = conf.get("smtp", "SMTP_HOST")
    SMTP_PORT = conf.getint("smtp", "SMTP_PORT")
    SMTP_USER = conf.get("smtp", "SMTP_USER")
    SMTP_PASSWORD = conf.get("smtp", "SMTP_PASSWORD")
    SMTP_STARTTLS = conf.getboolean("smtp", "SMTP_STARTTLS")

    s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
    if SMTP_STARTTLS:
        s.starttls()
    if SMTP_USER and SMTP_PASSWORD:
        s.login(SMTP_USER, SMTP_PASSWORD)
    logging.info("Sent an alert email to " + str(e_to))
    s.sendmail(e_from, e_to, mime_msg.as_string())
    s.quit()
Example #6
def send_MIME_email(e_from, e_to, mime_msg, dryrun=False):
    SMTP_HOST = conf.get('smtp', 'SMTP_HOST')
    SMTP_PORT = conf.getint('smtp', 'SMTP_PORT')
    SMTP_USER = conf.get('smtp', 'SMTP_USER')
    SMTP_PASSWORD = conf.get('smtp', 'SMTP_PASSWORD')
    SMTP_STARTTLS = conf.getboolean('smtp', 'SMTP_STARTTLS')

    if not dryrun:
        s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
        if SMTP_STARTTLS:
            s.starttls()
        if SMTP_USER and SMTP_PASSWORD:
            s.login(SMTP_USER, SMTP_PASSWORD)
        logging.info("Sent an alert email to " + str(e_to))
        s.sendmail(e_from, e_to, mime_msg.as_string())
        s.quit()
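
A hypothetical caller built with the standard-library email package; with dryrun=True the SMTP round trip is skipped entirely, which is convenient for tests (all addresses below are placeholders):

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart()
msg['Subject'] = "Airflow alert"
msg['From'] = "airflow@example.com"
msg['To'] = "oncall@example.com"
msg.attach(MIMEText("Task example_task failed.", 'plain'))

send_MIME_email("airflow@example.com", ["oncall@example.com"], msg, dryrun=True)
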
Example #7
from time import sleep

from sqlalchemy import Column, Integer, String, DateTime, func, Index
from sqlalchemy.orm.session import make_transient

from airflow import executors, models, settings, utils
from airflow.configuration import conf
from airflow.utils import AirflowException, State


Base = models.Base
ID_LEN = models.ID_LEN

# Setting up a statsd client if needed
statsd = None
if conf.getboolean('scheduler', 'statsd_on'):
    from statsd import StatsClient
    statsd = StatsClient(
        host=conf.get('scheduler', 'statsd_host'),
        port=conf.getint('scheduler', 'statsd_port'),
        prefix=conf.get('scheduler', 'statsd_prefix'))


class BaseJob(Base):
    """
    Abstract class to be derived for jobs. Jobs are processing items with state
    and duration that aren't task instances. For instance a BackfillJob is
    a collection of task instance runs, but should have its own state, start
    and end time.
    """
Example #8
    def __init__(self):  # pylint: disable=too-many-statements
        configuration_dict = conf.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.kube_env_vars = configuration_dict.get('kubernetes_environment_variables', {})
        self.env_from_configmap_ref = conf.get(self.kubernetes_section,
                                               'env_from_configmap_ref')
        self.env_from_secret_ref = conf.get(self.kubernetes_section,
                                            'env_from_secret_ref')
        self.airflow_home = settings.AIRFLOW_HOME
        self.dags_folder = conf.get(self.core_section, 'dags_folder')
        self.parallelism = conf.getint(self.core_section, 'parallelism')
        self.worker_container_repository = conf.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = conf.get(
            self.kubernetes_section, 'worker_container_tag')
        self.kube_image = '{}:{}'.format(
            self.worker_container_repository, self.worker_container_tag)
        self.kube_image_pull_policy = conf.get(
            self.kubernetes_section, "worker_container_image_pull_policy"
        )
        self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
        self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file',
                                          fallback=None)

        kube_worker_annotations = conf.get(self.kubernetes_section, 'worker_annotations')
        if kube_worker_annotations:
            self.kube_annotations = json.loads(kube_worker_annotations)
        else:
            self.kube_annotations = None

        self.kube_labels = configuration_dict.get('kubernetes_labels', {})
        self.delete_worker_pods = conf.getboolean(
            self.kubernetes_section, 'delete_worker_pods')
        self.delete_worker_pods_on_failure = conf.getboolean(
            self.kubernetes_section, 'delete_worker_pods_on_failure')
        self.worker_pods_creation_batch_size = conf.getint(
            self.kubernetes_section, 'worker_pods_creation_batch_size')
        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section, 'image_pull_secrets')

        # NOTE: user can build the dags into the docker image directly,
        # this will set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section, 'dags_in_image')

        # Run as user for pod security context
        self.worker_run_as_user = self._get_security_context_val('run_as_user')
        self.worker_fs_group = self._get_security_context_val('fs_group')

        kube_worker_resources = conf.get(self.kubernetes_section, 'worker_resources')
        if kube_worker_resources:
            self.worker_resources = json.loads(kube_worker_resources)
        else:
            self.worker_resources = None

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Clone depth for git sync
        self.git_sync_depth = conf.get(self.kubernetes_section, 'git_sync_depth')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')
        # Optionally, the root directory for git operations
        self.git_sync_root = conf.get(self.kubernetes_section, 'git_sync_root')
        # Optionally, the name at which to publish the checked-out files under --root
        self.git_sync_dest = conf.get(self.kubernetes_section, 'git_sync_dest')
        # Optionally, the tag or hash to checkout
        self.git_sync_rev = conf.get(self.kubernetes_section, 'git_sync_rev')
        # Optionally, if git_dags_folder_mount_point is set the worker will use
        # {git_dags_folder_mount_point}/{git_sync_dest}/{git_subpath} as dags_folder
        self.git_dags_folder_mount_point = conf.get(self.kubernetes_section,
                                                    'git_dags_folder_mount_point')

        # Optionally a user may supply a (`git_user` AND `git_password`) OR
        # (`git_ssh_key_secret_name` AND `git_ssh_key_secret_key`) for private repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')
        self.git_ssh_key_secret_name = conf.get(self.kubernetes_section, 'git_ssh_key_secret_name')
        self.git_ssh_known_hosts_configmap_name = conf.get(self.kubernetes_section,
                                                           'git_ssh_known_hosts_configmap_name')
        self.git_sync_credentials_secret = conf.get(self.kubernetes_section,
                                                    'git_sync_credentials_secret')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section, 'dags_volume_claim')

        self.dags_volume_mount_point = conf.get(self.kubernetes_section, 'dags_volume_mount_point')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section, 'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(
            self.kubernetes_section, 'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(
            self.kubernetes_section, 'logs_volume_subpath')

        # Optionally, hostPath volume containing DAGs
        self.dags_volume_host = conf.get(self.kubernetes_section, 'dags_volume_host')

        # Optionally, write logs to a hostPath Volume
        self.logs_volume_host = conf.get(self.kubernetes_section, 'logs_volume_host')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = conf.get(self.logging_section, 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside. Note
        # that if your
        # cluster has RBAC enabled, your scheduler may need service account permissions to
        # create, watch, get, and delete pods in this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        self.multi_namespace_mode = conf.getboolean(self.kubernetes_section, 'multi_namespace_mode')
        # The Kubernetes Namespace in which pods will be created by the executor. Note
        # that if your
        # cluster has RBAC enabled, your workers may need service account permissions to
        # interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section, 'namespace')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(
            self.kubernetes_section, 'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        self.git_sync_run_as_user = self._get_security_context_val('git_sync_run_as_user')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section, 'airflow_configmap')

        # The worker pod may optionally have a valid Airflow local settings loaded via a
        # configmap
        self.airflow_local_settings_configmap = conf.get(
            self.kubernetes_section, 'airflow_local_settings_configmap')

        affinity_json = conf.get(self.kubernetes_section, 'affinity')
        if affinity_json:
            self.kube_affinity = json.loads(affinity_json)
        else:
            self.kube_affinity = None

        tolerations_json = conf.get(self.kubernetes_section, 'tolerations')
        if tolerations_json:
            self.kube_tolerations = json.loads(tolerations_json)
        else:
            self.kube_tolerations = None

        kube_client_request_args = conf.get(self.kubernetes_section, 'kube_client_request_args')
        if kube_client_request_args:
            self.kube_client_request_args = json.loads(kube_client_request_args)
            if self.kube_client_request_args['_request_timeout'] and \
                    isinstance(self.kube_client_request_args['_request_timeout'], list):
                self.kube_client_request_args['_request_timeout'] = \
                    tuple(self.kube_client_request_args['_request_timeout'])
        else:
            self.kube_client_request_args = {}
        self._validate()

        delete_option_kwargs = conf.get(self.kubernetes_section, 'delete_option_kwargs')
        if delete_option_kwargs:
            self.delete_option_kwargs = json.loads(delete_option_kwargs)
        else:
            self.delete_option_kwargs = {}
Example #9
def configure_orm(disable_connection_pool=False):
    log.debug("Setting up DB connection pool (PID %s)" % os.getpid())
    global engine
    global Session
    engine_args = {}

    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.configure_orm(): Using NullPool")
    elif 'sqlite' not in SQL_ALCHEMY_CONN:
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of “sleeping” connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('core', 'SQL_ALCHEMY_MAX_OVERFLOW', fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE', fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like “SELECT 1”, but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('core', 'SQL_ALCHEMY_POOL_PRE_PING', fallback=True)

        log.info("settings.configure_orm(): Using pool settings. pool_size={}, max_overflow={}, "
                 "pool_recycle={}, pid={}".format(pool_size, max_overflow, pool_recycle, os.getpid()))
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use
    # us.
    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')
    # For Python2 we get back a newstr and need a str
    engine_args['encoding'] = engine_args['encoding'].__str__()

    engine = create_engine(SQL_ALCHEMY_CONN, **engine_args)
    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(autocommit=False,
                     autoflush=False,
                     bind=engine,
                     expire_on_commit=False))
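
Example #3 and Example #9 read the same pool option in two different styles; as a stand-alone sketch, both patterns below are equivalent ways of supplying a default for a missing option:

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException

# Older style: catch the exception raised when the option is missing.
try:
    pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE')
except AirflowConfigException:
    pool_size = 5

# Newer style: let the accessor supply the default directly.
pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)
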
Example #10
    def collect_dags(
            self,
            dag_folder=None,
            only_if_updated=True,
            include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
            safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a ``.airflowignore`` file is found while processing
        the directory, it will behave much like a ``.gitignore``,
        ignoring files that match any of the regex patterns specified
        in the file.

        **Note**: The patterns in .airflowignore are treated as
        un-anchored regexes, not shell-like glob patterns.
        """
        start_dttm = timezone.utcnow()
        dag_folder = dag_folder or self.dag_folder
        # Used to store stats around DagBag processing
        stats = []
        FileLoadStat = namedtuple(
            'FileLoadStat', "file duration dag_num task_num dags")

        dag_folder = correct_maybe_zipped(dag_folder)

        for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode,
                                           include_examples=include_examples):
            try:
                ts = timezone.utcnow()
                found_dags = self.process_file(
                    filepath, only_if_updated=only_if_updated,
                    safe_mode=safe_mode)
                dag_ids = [dag.dag_id for dag in found_dags]
                dag_id_names = str(dag_ids)

                td = timezone.utcnow() - ts
                td = td.total_seconds() + (
                    float(td.microseconds) / 1000000)
                stats.append(FileLoadStat(
                    filepath.replace(dag_folder, ''),
                    td,
                    len(found_dags),
                    sum([len(dag.tasks) for dag in found_dags]),
                    dag_id_names,
                ))
            except Exception as e:
                self.log.exception(e)
        Stats.gauge(
            'collect_dags', (timezone.utcnow() - start_dttm).total_seconds(), 1)
        Stats.gauge(
            'dagbag_size', len(self.dags), 1)
        Stats.gauge(
            'dagbag_import_errors', len(self.import_errors), 1)
        self.dagbag_stats = sorted(
            stats, key=lambda x: x.duration, reverse=True)
        for file_stat in self.dagbag_stats:
            # file_stat.file similar format: /subdir/dag_name.py
            filename = file_stat.file.split('/')[-1].replace('.py', '')
            Stats.timing('dag.loading-duration.{}'.
                         format(filename),
                         file_stat.duration)
Example #11
    def __init__(self):
        super().__init__()
        self.tasks_to_run: List[TaskInstance] = []
        # Place where we keep information for task instance raw run
        self.tasks_params: Dict[TaskInstanceKeyType, Dict[str, Any]] = {}
        self.fail_fast = conf.getboolean("debug", "fail_fast")
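
The fail_fast flag above is read from the [debug] section of airflow.cfg; like any Airflow option it can also be overridden through the AIRFLOW__<SECTION>__<KEY> environment-variable convention, for example:

import os

# Equivalent to setting fail_fast = True under [debug] in airflow.cfg.
os.environ["AIRFLOW__DEBUG__FAIL_FAST"] = "True"
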
Example #12
def prepare_engine_args(disable_connection_pool=False):
    """Prepare SQLAlchemy engine args"""
    engine_args = {}
    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.prepare_engine_args(): Using NullPool")
    elif not SQL_ALCHEMY_CONN.startswith('sqlite'):
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of “sleeping” connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('core',
                                   'SQL_ALCHEMY_MAX_OVERFLOW',
                                   fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('core',
                                   'SQL_ALCHEMY_POOL_RECYCLE',
                                   fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like “SELECT 1”, but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('core',
                                        'SQL_ALCHEMY_POOL_PRE_PING',
                                        fallback=True)

        log.debug(
            "settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, "
            "pool_recycle=%d, pid=%d",
            pool_size,
            max_overflow,
            pool_recycle,
            os.getpid(),
        )
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # The default isolation level for MySQL (REPEATABLE READ) can introduce inconsistencies when
    # running multiple schedulers, as repeated queries on the same session may read from stale snapshots.
    # 'READ COMMITTED' is the default value for PostgreSQL.
    # More information here:
    # https://dev.mysql.com/doc/refman/8.0/en/innodb-transaction-isolation-levels.html"
    if SQL_ALCHEMY_CONN.startswith('mysql'):
        engine_args['isolation_level'] = 'READ COMMITTED'

    return engine_args
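
A sketch of how the returned dictionary might be consumed, mirroring what configure_orm() does in the earlier examples (the connection URI is read from [core] sql_alchemy_conn):

from airflow.configuration import conf
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
engine = create_engine(SQL_ALCHEMY_CONN, **prepare_engine_args())
Session = scoped_session(
    sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False))
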
Example #13
def create_app(config=None, testing=False, app_name="Airflow"):
    """Create a new instance of Airflow WWW app"""
    flask_app = Flask(__name__)
    flask_app.secret_key = conf.get('webserver', 'SECRET_KEY')

    flask_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        minutes=settings.get_session_lifetime_config())
    flask_app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    flask_app.config['APP_NAME'] = app_name
    flask_app.config['TESTING'] = testing
    flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get(
        'core', 'SQL_ALCHEMY_CONN')
    flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    flask_app.config['SESSION_COOKIE_HTTPONLY'] = True
    flask_app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    flask_app.config['SESSION_COOKIE_SAMESITE'] = conf.get(
        'webserver', 'COOKIE_SAMESITE')

    if config:
        flask_app.config.from_mapping(config)

    if 'SQLALCHEMY_ENGINE_OPTIONS' not in flask_app.config:
        flask_app.config[
            'SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()

    # Configure the JSON encoder used by `|tojson` filter from Flask
    flask_app.json_encoder = AirflowJsonEncoder

    csrf.init_app(flask_app)

    init_wsgi_middleware(flask_app)

    db = SQLA()
    db.session = settings.Session
    db.init_app(flask_app)

    init_dagbag(flask_app)

    init_api_experimental_auth(flask_app)

    Cache(app=flask_app,
          config={
              'CACHE_TYPE': 'filesystem',
              'CACHE_DIR': '/tmp'
          })

    init_flash_views(flask_app)

    configure_logging()
    configure_manifest_files(flask_app)

    with flask_app.app_context():
        init_appbuilder(flask_app)

        init_appbuilder_views(flask_app)
        init_appbuilder_links(flask_app)
        init_plugins(flask_app)
        init_connection_form()
        init_error_handlers(flask_app)
        init_api_connexion(flask_app)
        init_api_experimental(flask_app)

        sync_appbuilder_roles(flask_app)

        init_jinja_globals(flask_app)
        init_xframe_protection(flask_app)
        init_permanent_session(flask_app)
    return flask_app
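
A hypothetical way to exercise the factory locally, for example during development or in tests, using Flask's built-in server:

app = create_app(testing=True)
app.run(host="127.0.0.1", port=8080, debug=True)
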
Example #14
KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'}

# Updating serialized DAG can not be faster than a minimum interval to reduce database
# write rate.
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
    'core', 'min_serialized_dag_update_interval', fallback=30)

# Fetching serialized DAG can not be faster than a minimum interval to reduce database
# read rate. This config controls when your DAGs are updated in the Webserver
MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint(
    'core', 'min_serialized_dag_fetch_interval', fallback=10)

# Whether to persist DAG files code in DB. If set to True, Webserver reads file contents
# from DB instead of trying to access files in a DAG folder.
STORE_DAG_CODE = conf.getboolean("core", "store_dag_code", fallback=True)

# If donot_modify_handlers=True, we do not modify logging handlers in task_run command
# If the flag is set to False, we remove all handlers from the root logger
# and add all handlers from 'airflow.task' logger to the root Logger. This is done
# to get all the logs from the print & log statements in the DAG files before a task is run
# The handlers are restored after the task completes execution.
DONOT_MODIFY_HANDLERS = conf.getboolean('logging',
                                        'donot_modify_handlers',
                                        fallback=False)

CAN_FORK = hasattr(os, "fork")

EXECUTE_TASKS_NEW_PYTHON_INTERPRETER = not CAN_FORK or conf.getboolean(
    'core',
    'execute_tasks_new_python_interpreter',
    fallback=False,
)
Example #15
    def execute(self, context):
        if self.ssh_conn_id and not self.winrm_hook:
            self.log.info("Hook not found, creating...")
            self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id)

        if not self.winrm_hook:
            raise AirflowException(
                "Cannot operate without winrm_hook or ssh_conn_id.")

        if self.remote_host is not None:
            self.winrm_hook.remote_host = self.remote_host

        if not self.command:
            raise AirflowException(
                "No command specified so nothing to execute here.")

        winrm_client = self.winrm_hook.get_conn()

        try:
            self.log.info("Running command: '%s'...", self.command)
            command_id = self.winrm_hook.winrm_protocol.run_command(
                winrm_client, self.command)

            # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
            stdout_buffer = []
            stderr_buffer = []
            command_done = False
            while not command_done:
                try:
                    stdout, stderr, return_code, command_done = \
                        self.winrm_hook.winrm_protocol._raw_get_command_output(
                            winrm_client,
                            command_id
                        )

                    # Only buffer stdout if we need to so that we minimize memory usage.
                    if self.do_xcom_push:
                        stdout_buffer.append(stdout)
                    stderr_buffer.append(stderr)

                    for line in stdout.decode('utf-8').splitlines():
                        self.log.info(line)
                    for line in stderr.decode('utf-8').splitlines():
                        self.log.warning(line)
                except WinRMOperationTimeoutError:
                    # this is an expected error when waiting for a
                    # long-running process, just silently retry
                    pass

            self.winrm_hook.winrm_protocol.cleanup_command(
                winrm_client, command_id)
            self.winrm_hook.winrm_protocol.close_shell(winrm_client)

        except Exception as e:
            raise AirflowException("WinRM operator error: {0}".format(str(e)))

        if return_code == 0:
            # returning output if do_xcom_push is set
            enable_pickling = conf.getboolean('core', 'enable_xcom_pickling')
            if enable_pickling:
                return stdout_buffer
            else:
                return b64encode(b''.join(stdout_buffer)).decode('utf-8')
        else:
            error_msg = "Error running cmd: {0}, return code: {1}, error: {2}".format(
                self.command, return_code,
                b''.join(stderr_buffer).decode('utf-8'))
            raise AirflowException(error_msg)
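
When do_xcom_push is enabled and enable_xcom_pickling is off, the operator returns the captured stdout base64-encoded; a small hypothetical helper for decoding that value downstream:

from base64 import b64decode

def decode_winrm_xcom(value: str) -> str:
    """Reverse the b64encode(b''.join(stdout_buffer)).decode('utf-8') step above."""
    return b64decode(value).decode('utf-8')
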
Example #16
    def __init__(
        self,
        task_id: str,
        owner: str = conf.get('operators', 'DEFAULT_OWNER'),
        email: Optional[Union[str, Iterable[str]]] = None,
        email_on_retry: bool = True,
        email_on_failure: bool = True,
        retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
        retry_delay: timedelta = timedelta(seconds=300),
        retry_exponential_backoff: bool = False,
        max_retry_delay: Optional[datetime] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        depends_on_past: bool = False,
        wait_for_downstream: bool = False,
        dag=None,
        params: Optional[Dict] = None,
        default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
        priority_weight: int = 1,
        weight_rule: str = WeightRule.DOWNSTREAM,
        queue: str = conf.get('celery', 'default_queue'),
        pool: str = Pool.DEFAULT_POOL_NAME,
        sla: Optional[timedelta] = None,
        execution_timeout: Optional[timedelta] = None,
        on_execute_callback: Optional[Callable] = None,
        on_failure_callback: Optional[Callable] = None,
        on_success_callback: Optional[Callable] = None,
        on_retry_callback: Optional[Callable] = None,
        trigger_rule: str = TriggerRule.ALL_SUCCESS,
        resources: Optional[Dict] = None,
        run_as_user: Optional[str] = None,
        task_concurrency: Optional[int] = None,
        executor_config: Optional[Dict] = None,
        do_xcom_push: bool = True,
        inlets: Optional[Any] = None,
        outlets: Optional[Any] = None,
        *args,
        **kwargs
    ):
        from airflow.models.dag import DagContext
        super().__init__()
        if args or kwargs:
            if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                raise AirflowException(
                    "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                    "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                        c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                )
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'future. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3
            )
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_triggers=TriggerRule.all_triggers(),
                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_execute_callback = on_execute_callback
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            # noinspection PyTypeChecker
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_weight_rules=WeightRule.all_weight_rules,
                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
        self.weight_rule = weight_rule
        self.resources: Optional[Resources] = Resources(**resources) if resources else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()
        self._dag = None

        self.dag = dag or DagContext.get_current_dag()

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        from airflow.models.dag import DAG
        self.subdag: Optional[DAG] = None

        self._log = logging.getLogger("airflow.task.operators")

        # Lineage
        self.inlets: List = []
        self.outlets: List = []

        self._inlets: List = []
        self._outlets: List = []

        if inlets:
            self._inlets = inlets if isinstance(inlets, list) else [inlets, ]

        if outlets:
            self._outlets = outlets if isinstance(outlets, list) else [outlets, ]
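
Several of the defaults above are resolved from airflow.cfg at construction time (owner, queue, retries); a hypothetical minimal subclass illustrating that only explicitly passed arguments override those defaults:

from datetime import timedelta

from airflow.models import BaseOperator

class NoOpOperator(BaseOperator):
    """Hypothetical operator used only to illustrate the constructor defaults."""

    def execute(self, context):
        pass

# owner falls back to [operators] DEFAULT_OWNER and queue to [celery] default_queue;
# retries and retry_delay are overridden explicitly here.
task = NoOpOperator(task_id="example_task", retries=3, retry_delay=timedelta(minutes=5))
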
Example #17
def webserver(args):
    """Starts Airflow Webserver"""
    print(settings.HEADER)

    access_logfile = args.access_logfile or conf.get('webserver',
                                                     'access_logfile')
    error_logfile = args.error_logfile or conf.get('webserver',
                                                   'error_logfile')
    num_workers = args.workers or conf.get('webserver', 'workers')
    worker_timeout = (args.worker_timeout
                      or conf.get('webserver', 'web_server_worker_timeout'))
    ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert')
    ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key')
    if not ssl_cert and ssl_key:
        raise AirflowException(
            'An SSL certificate must also be provided for use with ' + ssl_key)
    if ssl_cert and not ssl_key:
        raise AirflowException(
            'An SSL key must also be provided for use with ' + ssl_cert)

    if args.debug:
        print("Starting the web server on port {0} and host {1}.".format(
            args.port, args.hostname))
        app, _ = create_app(None,
                            testing=conf.getboolean('core', 'unit_test_mode'))
        app.run(debug=True,
                use_reloader=not app.config['TESTING'],
                port=args.port,
                host=args.hostname,
                ssl_context=(ssl_cert,
                             ssl_key) if ssl_cert and ssl_key else None)
    else:
        os.environ['SKIP_DAGS_PARSING'] = 'True'
        app = cached_app(None)
        pid, stdout, stderr, log_file = setup_locations(
            "webserver", args.pid, args.stdout, args.stderr, args.log_file)
        os.environ.pop('SKIP_DAGS_PARSING')
        if args.daemon:
            handle = setup_logging(log_file)
            stdout = open(stdout, 'w+')
            stderr = open(stderr, 'w+')

        print(
            textwrap.dedent('''\
                Running the Gunicorn Server with:
                Workers: {num_workers} {workerclass}
                Host: {hostname}:{port}
                Timeout: {worker_timeout}
                Logfiles: {access_logfile} {error_logfile}
                =================================================================\
            '''.format(num_workers=num_workers,
                       workerclass=args.workerclass,
                       hostname=args.hostname,
                       port=args.port,
                       worker_timeout=worker_timeout,
                       access_logfile=access_logfile,
                       error_logfile=error_logfile)))

        run_args = [
            'gunicorn',
            '-w',
            str(num_workers),
            '-k',
            str(args.workerclass),
            '-t',
            str(worker_timeout),
            '-b',
            args.hostname + ':' + str(args.port),
            '-n',
            'airflow-webserver',
            '-p',
            str(pid),
            '-c',
            'python:airflow.www.gunicorn_config',
        ]

        if args.access_logfile:
            run_args += ['--access-logfile', str(args.access_logfile)]

        if args.error_logfile:
            run_args += ['--error-logfile', str(args.error_logfile)]

        if args.daemon:
            run_args += ['-D']

        if ssl_cert:
            run_args += ['--certfile', ssl_cert, '--keyfile', ssl_key]

        webserver_module = 'www'
        run_args += ["airflow." + webserver_module + ".app:cached_app()"]

        gunicorn_master_proc = None

        def kill_proc(dummy_signum, dummy_frame):  # pylint: disable=unused-argument
            gunicorn_master_proc.terminate()
            gunicorn_master_proc.wait()
            sys.exit(0)

        def monitor_gunicorn(gunicorn_master_proc):
            # These run forever until SIG{INT, TERM, KILL, ...} signal is sent
            if conf.getint('webserver', 'worker_refresh_interval') > 0:
                master_timeout = conf.getint('webserver',
                                             'web_server_master_timeout')
                restart_workers(gunicorn_master_proc, num_workers,
                                master_timeout)
            else:
                while gunicorn_master_proc.poll() is None:
                    time.sleep(1)

                sys.exit(gunicorn_master_proc.returncode)

        if args.daemon:
            base, ext = os.path.splitext(pid)
            ctx = daemon.DaemonContext(
                pidfile=TimeoutPIDLockFile(base + "-monitor" + ext, -1),
                files_preserve=[handle],
                stdout=stdout,
                stderr=stderr,
                signal_map={
                    signal.SIGINT: kill_proc,
                    signal.SIGTERM: kill_proc
                },
            )
            with ctx:
                subprocess.Popen(run_args, close_fds=True)

                # Reading pid file directly, since Popen#pid doesn't
                # seem to return the right value with DaemonContext.
                while True:
                    try:
                        with open(pid) as file:
                            gunicorn_master_proc_pid = int(file.read())
                            break
                    except OSError:
                        log.debug(
                            "Waiting for gunicorn's pid file to be created.")
                        time.sleep(0.1)

                gunicorn_master_proc = psutil.Process(gunicorn_master_proc_pid)
                monitor_gunicorn(gunicorn_master_proc)

            stdout.close()
            stderr.close()
        else:
            gunicorn_master_proc = subprocess.Popen(run_args, close_fds=True)

            signal.signal(signal.SIGINT, kill_proc)
            signal.signal(signal.SIGTERM, kill_proc)

            monitor_gunicorn(gunicorn_master_proc)
Example #18
def dag_backfill(args, dag=None):
    """Creates backfill job or dry run for a DAG"""
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)

    signal.signal(signal.SIGTERM, sigint_handler)

    import warnings

    warnings.warn(
        '--ignore-first-depends-on-past is deprecated as the value is always set to True',
        category=PendingDeprecationWarning,
    )

    if args.ignore_first_depends_on_past is False:
        args.ignore_first_depends_on_past = True

    dag = dag or get_dag(args.subdir, args.dag_id)

    if not args.start_date and not args.end_date:
        raise AirflowException("Provide a start_date and/or end_date")

    # If only one date is passed, use it for both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.partial_subset(
            task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies
        )

    run_conf = None
    if args.conf:
        run_conf = json.loads(args.conf)

    if args.dry_run:
        print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
        for task in dag.tasks:
            print(f"Task {task.task_id}")
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        if args.reset_dagruns:
            DAG.clear_dags(
                [dag],
                start_date=args.start_date,
                end_date=args.end_date,
                confirm_prompt=not args.yes,
                include_subdags=True,
                dag_run_state=State.NONE,
            )

        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            local=args.local,
            donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
            ignore_first_depends_on_past=args.ignore_first_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            pool=args.pool,
            delay_on_limit_secs=args.delay_on_limit,
            verbose=args.verbose,
            conf=run_conf,
            rerun_failed_tasks=args.rerun_failed_tasks,
            run_backwards=args.run_backwards,
        )
Example #19
    def close(self):
        """
        Close and upload local log file to remote storage S3.
        """
        # When application exit, system shuts down all handlers by
        # calling close method. Here we check if logger is already
        # closed to prevent uploading the log to remote storage multiple
        # times when `logging.shutdown` is called.
        super(QDSTaskHandler, self).close()
        if self.closed or not self.log_relative_path:
            self.closed = True
            return

        local_loc = os.path.join(self.local_base, self.log_relative_path)
        remote_loc = os.path.join(self.remote_base, self.log_relative_path)
        provider = "kubernetes" if conf.getboolean(
            "qubole", "K8S_RUNTIME") else os.environ.get('PROVIDER', "")
        if os.path.exists(local_loc):
            if provider == "kubernetes":
                execute_cmd = ['aws', 's3', 'cp']
                execute_cmd.extend([local_loc, remote_loc])
            elif provider == 'azure':
                # We have to do this because azure does not support colon in the file path
                self.log_relative_path = self.log_relative_path.replace(
                    ':', '.')
                execute_cmd = [
                    '/usr/lib/hadoop2/bin/hadoop', 'dfs',
                    '-Dfs.azure.account.key.{}.blob.core.windows.net={}'.
                    format(os.environ.get("AZURE_STORAGE_ACCOUNT"),
                           os.environ.get("AZURE_STORAGE_ACCESS_KEY"))
                ]
                execute_cmd.extend([
                    "-copyFromLocal", "-f", local_loc,
                    "{}/{}".format(self.remote_base, self.log_relative_path)
                ])
            elif provider == 'oracle_bmc':
                tmp_command_directory = tempfile.mkdtemp()
                self.log_relative_path = self.log_relative_path.replace(
                    ':', '.')
                config = json.load(open('/root/config.json', "r+"))
                execute_cmd = '/usr/lib/hadoop2/bin/hadoop' + ' dfs' + ' -Dfs.oraclebmc.client.auth.tenantId={}'.format(
                    config['cloud_config']['storage_tenant_id'])
                execute_cmd += ' -Dfs.oraclebmc.client.auth.userId={}'.format(
                    config['cloud_config']['compute_user_id'])
                execute_cmd += ' -Dfs.oraclebmc.client.auth.fingerprint={}'.format(
                    config['cloud_config']['storage_key_finger_print'])
                execute_cmd += ' -Dfs.oraclebmc.client.auth.pemfilecontent="{}"'.format(
                    config['cloud_config']['storage_api_private_rsa_key'])
                execute_cmd += " -copyFromLocal" + " -f" + " " + local_loc + " {}/{}".format(
                    self.remote_base, self.log_relative_path)
                command_file = open(tmp_command_directory + "/command.sh",
                                    "w+")
                command_file.write(execute_cmd)
                command_file.close()
                os.system("source {}/command.sh".format(tmp_command_directory))
            elif provider == "gcp":
                remote_loc = os.path.join(
                    self.remote_base, self.log_relative_path.replace(':', '.'))
                execute_cmd = ["gsutil", "-m", "cp", "-R"]
                execute_cmd.extend([local_loc, remote_loc])
            else:
                execute_cmd = [
                    's3cmd', 'put', '--recursive', '-c',
                    '/usr/lib/hustler/s3cfg'
                ]
                execute_cmd.extend([local_loc, remote_loc])

            if provider[:6] != "oracle":
                process = subprocess.Popen(execute_cmd,
                                           stdout=subprocess.PIPE,
                                           stderr=subprocess.PIPE)

        self.closed = True
Example #20
    def _read(self, ti, try_number, metadata=None):
        """
        Read logs of given task instance and try_number from S3 remote storage.
        If failed, read the log from task instance host machine.
        :param ti: task instance object
        :param try_number: task instance try_number to read logs from
        """
        # Explicitly getting log relative path is necessary as the given
        # task instance might be different from the task instance passed
        # in via the set_context method.

        # fetch the logs first either from local file system or worker http endpoint
        log, condition_map = super(QDSTaskHandler,
                                   self)._read(ti, try_number, metadata)
        if not condition_map.get("error_while_fetch", True):
            return log, condition_map

        # fetch from remote storage if not available locally
        log_relative_path = self._render_filename(ti, try_number)
        remote_loc = os.path.join(self.remote_base, log_relative_path)
        tmp_file = tempfile.NamedTemporaryFile()
        provider = "kubernetes" if conf.getboolean(
            "qubole", "K8S_RUNTIME") else os.environ.get('PROVIDER', "")

        if provider == "kubernetes":
            execute_cmd = ['aws', 's3', 'cp']
            execute_cmd.extend([remote_loc, tmp_file.name])
        elif provider == 'azure':
            # We have to do this because azure does not support colon in the file path
            log_relative_path = log_relative_path.replace(':', '.')
            execute_cmd = [
                '/usr/lib/hadoop2/bin/hadoop', 'dfs',
                '-Dfs.azure.account.key.{}.blob.core.windows.net={}'.format(
                    os.environ.get("AZURE_STORAGE_ACCOUNT"),
                    os.environ.get("AZURE_STORAGE_ACCESS_KEY"))
            ]
            execute_cmd.extend([
                "-copyToLocal", "-f",
                "{}/{}".format(conf.get("core", "remote_base_log_folder"),
                               log_relative_path), tmp_file.name
            ])
        elif provider == 'oracle_bmc':
            tmp_command_directory = tempfile.mkdtemp()
            log_relative_path = log_relative_path.replace(':', '.')
            config = json.load(open('/root/config.json', "r+"))
            execute_cmd = '/usr/lib/hadoop2/bin/hadoop' + ' dfs' + ' -Dfs.oraclebmc.client.auth.tenantId={}'.format(
                config['cloud_config']['storage_tenant_id'])
            execute_cmd += ' -Dfs.oraclebmc.client.auth.userId={}'.format(
                config['cloud_config']['compute_user_id'])
            execute_cmd += ' -Dfs.oraclebmc.client.auth.fingerprint={}'.format(
                config['cloud_config']['storage_key_finger_print'])
            execute_cmd += ' -Dfs.oraclebmc.client.auth.pemfilecontent="{}"'.format(
                config['cloud_config']['storage_api_private_rsa_key'])
            execute_cmd += " -copyToLocal" + " -f" + " {}/{}".format(
                conf.get("core", "remote_base_log_folder"),
                log_relative_path) + " " + tmp_file.name
            command_file = open(tmp_command_directory + "/command.sh", "w+")
            command_file.write(execute_cmd)
            command_file.close()
            os.system("source {}/command.sh".format(tmp_command_directory))
        elif provider == 'gcp':
            remote_loc = os.path.join(self.remote_base,
                                      self.log_relative_path.replace(':', '.'))
            execute_cmd = ["gsutil", "-m", "cp", "-R"]
            execute_cmd.extend([remote_loc, tmp_file.name])
        else:
            execute_cmd = ['s3cmd', 'get', '-c', '/usr/lib/hustler/s3cfg']
            execute_cmd.extend([remote_loc, tmp_file.name, '--force'])

        if provider[:6] != "oracle":
            process = subprocess.Popen(execute_cmd,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)
            process.wait()
        log += '*** Fetching log from Remote: {}\n'.format(remote_loc)
        log += ('*** Note: Remote logs are only available once '
                'tasks have completed.\n')
        try:
            with open(tmp_file.name) as f:
                log += "".join(f.readlines())
        except Exception as e:
            log = ''

        return log, {'end_of_log': True}
Example #21
class DagBag(BaseDagBag, LoggingMixin):
    """
    A dagbag is a collection of dags, parsed out of a folder tree and has high
    level configuration settings, like what database to use as a backend and
    what executor to use to fire off tasks. This makes it easier to run
    distinct environments for say production and development, tests, or for
    different teams or security profiles. What would have been system level
    settings are now dagbag level so that one system can run multiple,
    independent settings sets.

    :param dag_folder: the folder to scan to find DAGs
    :type dag_folder: unicode
    :param executor: the executor to use when executing task instances
        in this DagBag
    :param include_examples: whether to include the examples that ship
        with airflow or not
    :type include_examples: bool
    :param has_logged: an instance boolean that gets flipped from False to True after a
        file has been skipped. This is to prevent overloading the user with logging
        messages about skipped files. Therefore a skipped file is only logged
        once per DagBag.
    :param store_serialized_dags: Read DAGs from DB if store_serialized_dags is ``True``.
        If ``False`` DAGs are read from python files.
    :type store_serialized_dags: bool
    """

    # static class variables to detect DAG cycles
    CYCLE_NEW = 0
    CYCLE_IN_PROGRESS = 1
    CYCLE_DONE = 2
    DAGBAG_IMPORT_TIMEOUT = conf.getint('core', 'DAGBAG_IMPORT_TIMEOUT')
    UNIT_TEST_MODE = conf.getboolean('core', 'UNIT_TEST_MODE')
    SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

    def __init__(
            self,
            dag_folder=None,
            executor=None,
            include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
            safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
            store_serialized_dags=False,
    ):

        # do not use default arg in signature, to fix import cycle on plugin load
        if executor is None:
            from airflow.executors.executor_loader import ExecutorLoader
            executor = ExecutorLoader.get_default_executor()
        dag_folder = dag_folder or settings.DAGS_FOLDER
        self.dag_folder = dag_folder
        self.dags = {}
        # the file's last modified timestamp when we last read it
        self.file_last_changed = {}
        self.executor = executor
        self.import_errors = {}
        self.has_logged = False
        self.store_serialized_dags = store_serialized_dags

        self.collect_dags(
            dag_folder=dag_folder,
            include_examples=include_examples,
            safe_mode=safe_mode)

    def size(self):
        """
        :return: the number of DAGs contained in this DagBag
        """
        return len(self.dags)

    @property
    def dag_ids(self):
        return self.dags.keys()

    def get_dag(self, dag_id, from_file_only=False):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        :param from_file_only: if True, always return a DAG loaded from its
            source file instead of from the serialized DAG table.
        :type from_file_only: bool
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        # Only read DAGs from DB if this dagbag is store_serialized_dags.
        # from_file_only is an exception: currently it is only used for rendering
        # templates in the UI, because functions are stripped from serialized DAGs
        # and such DAGs must be imported from files.
        # FIXME: this exception should be removed in the future, so that the
        # webserver can be decoupled from DAG files.
        if self.store_serialized_dags and not from_file_only:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel
            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                row = SerializedDagModel.get(dag_id)
                if not row:
                    return None

                dag = row.dag
                for subdag in dag.subdags:
                    self.dags[subdag.dag_id] = subdag
                self.dags[dag.dag_id] = dag

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # Needs to load from file for a store_serialized_dags dagbag.
        enforce_from_file = False
        if self.store_serialized_dags and dag is not None:
            from airflow.serialization.serialized_objects import SerializedDAG
            enforce_from_file = isinstance(dag, SerializedDAG)

        # If the dag corresponding to root_dag_id is absent or expired
        orm_dag = DagModel.get_current(root_dag_id)
        if (orm_dag and (
                root_dag_id not in self.dags or
                (
                    orm_dag.last_expired and
                    dag.last_loaded < orm_dag.last_expired
                )
        )) or enforce_from_file:
            # Reprocess source file
            found_dags = self.process_file(
                filepath=correct_maybe_zipped(orm_dag.fileloc), only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [found_dag.dag_id for found_dag in found_dags]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)

    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for DAG objects within it.
        """
        from airflow.models.dag import DAG  # Avoid circular import

        found_dags = []

        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return found_dags

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.file_last_changed \
                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
                return found_dags

        except Exception as e:
            self.log.exception(e)
            return found_dags

        mods = []
        is_zipfile = zipfile.is_zipfile(filepath)
        if not is_zipfile:
            if safe_mode:
                with open(filepath, 'rb') as file:
                    content = file.read()
                    if not all([s in content for s in (b'DAG', b'airflow')]):
                        self.file_last_changed[filepath] = file_last_changed_on_disk
                        # Don't want to spam user with skip messages
                        if not self.has_logged:
                            self.has_logged = True
                            self.log.info(
                                "File %s assumed to contain no DAGs. Skipping.",
                                filepath)
                        return found_dags

            self.log.debug("Importing %s", filepath)
            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
            mod_name = ('unusual_prefix_' +
                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                        '_' + org_mod_name)

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            with timeout(self.DAGBAG_IMPORT_TIMEOUT):
                try:
                    m = imp.load_source(mod_name, filepath)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.import_errors[filepath] = str(e)
                    self.file_last_changed[filepath] = file_last_changed_on_disk

        else:
            zip_file = zipfile.ZipFile(filepath)
            for mod in zip_file.infolist():
                head, _ = os.path.split(mod.filename)
                mod_name, ext = os.path.splitext(mod.filename)
                if not head and (ext == '.py' or ext == '.pyc'):
                    if mod_name == '__init__':
                        self.log.warning("Found __init__.%s at root of %s", ext, filepath)
                    if safe_mode:
                        with zip_file.open(mod.filename) as zf:
                            self.log.debug("Reading %s from %s", mod.filename, filepath)
                            content = zf.read()
                            if not all([s in content for s in (b'DAG', b'airflow')]):
                                self.file_last_changed[filepath] = (
                                    file_last_changed_on_disk)
                                # todo: create ignore list
                                # Don't want to spam user with skip messages
                                if not self.has_logged:
                                    self.has_logged = True
                                    self.log.info(
                                        "File %s assumed to contain no DAGs. Skipping.",
                                        filepath)

                    if mod_name in sys.modules:
                        del sys.modules[mod_name]

                    try:
                        sys.path.insert(0, filepath)
                        m = importlib.import_module(mod_name)
                        mods.append(m)
                    except Exception as e:
                        self.log.exception("Failed to import: %s", filepath)
                        self.import_errors[filepath] = str(e)
                        self.file_last_changed[filepath] = file_last_changed_on_disk

        for m in mods:
            for dag in list(m.__dict__.values()):
                if isinstance(dag, DAG):
                    if not dag.full_filepath:
                        dag.full_filepath = filepath
                        if dag.fileloc != filepath and not is_zipfile:
                            dag.fileloc = filepath
                    try:
                        dag.is_subdag = False
                        self.bag_dag(dag, parent_dag=dag, root_dag=dag)
                        if isinstance(dag._schedule_interval, str):
                            croniter(dag._schedule_interval)
                        found_dags.append(dag)
                        found_dags += dag.subdags
                    except (CroniterBadCronError,
                            CroniterBadDateError,
                            CroniterNotAlphaError) as cron_e:
                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                        self.import_errors[dag.full_filepath] = \
                            "Invalid Cron expression: " + str(cron_e)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk
                    except AirflowDagCycleException as cycle_exception:
                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                        self.import_errors[dag.full_filepath] = str(cycle_exception)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags

    @provide_session
    def kill_zombies(self, zombies, session=None):
        """
        Fail given zombie tasks, which are tasks that haven't
        had a heartbeat for too long, in the current DagBag.

        :param zombies: zombie task instances to kill.
        :param session: DB session.
        """
        from airflow.models.taskinstance import TaskInstance  # Avoid circular import

        for zombie in zombies:
            if zombie.dag_id in self.dags:
                dag = self.dags[zombie.dag_id]
                if zombie.task_id in dag.task_ids:
                    task = dag.get_task(zombie.task_id)
                    ti = TaskInstance(task, zombie.execution_date)
                    # Get properties needed for failure handling from SimpleTaskInstance.
                    ti.start_date = zombie.start_date
                    ti.end_date = zombie.end_date
                    ti.try_number = zombie.try_number
                    ti.state = zombie.state
                    ti.test_mode = self.UNIT_TEST_MODE
                    ti.handle_failure("{} detected as zombie".format(ti),
                                      ti.test_mode, ti.get_template_context())
                    self.log.info('Marked zombie job %s as %s', ti, ti.state)
                    Stats.incr('zombies_killed')
        session.commit()

    def bag_dag(self, dag, parent_dag, root_dag):
        """
        Adds the DAG into the bag, recurses into sub dags.
        Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags
        """

        dag.test_cycle()  # throws if a task cycle is found

        dag.resolve_template_files()
        dag.last_loaded = timezone.utcnow()

        for task in dag.tasks:
            settings.policy(task)

        subdags = dag.subdags

        try:
            for subdag in subdags:
                subdag.full_filepath = dag.full_filepath
                subdag.parent_dag = dag
                subdag.is_subdag = True
                self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag)

            self.dags[dag.dag_id] = dag
            self.log.debug('Loaded DAG %s', dag)
        except AirflowDagCycleException as cycle_exception:
            # There was an error in bagging the dag. Remove it from the list of dags
            self.log.exception('Exception bagging dag: %s', dag.dag_id)
            # Only necessary at the root level since DAG.subdags automatically
            # performs DFS to search through all subdags
            if dag == root_dag:
                for subdag in subdags:
                    if subdag.dag_id in self.dags:
                        del self.dags[subdag.dag_id]
            raise cycle_exception

    def collect_dags(
            self,
            dag_folder=None,
            only_if_updated=True,
            include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
            safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a ``.airflowignore`` file is found while processing
        the directory, it will behave much like a ``.gitignore``,
        ignoring files that match any of the regex patterns specified
        in the file.

        **Note**: The patterns in .airflowignore are treated as
        un-anchored regexes, not shell-like glob patterns.
        """
        if self.store_serialized_dags:
            return

        self.log.info("Filling up the DagBag from %s", dag_folder)
        start_dttm = timezone.utcnow()
        dag_folder = dag_folder or self.dag_folder
        # Used to store stats around DagBag processing
        stats = []
        FileLoadStat = namedtuple(
            'FileLoadStat', "file duration dag_num task_num dags")

        from airflow.utils.file import correct_maybe_zipped, list_py_file_paths
        dag_folder = correct_maybe_zipped(dag_folder)
        for filepath in list_py_file_paths(dag_folder, safe_mode=safe_mode,
                                           include_examples=include_examples):
            try:
                ts = timezone.utcnow()
                found_dags = self.process_file(
                    filepath, only_if_updated=only_if_updated,
                    safe_mode=safe_mode)
                dag_ids = [dag.dag_id for dag in found_dags]
                dag_id_names = str(dag_ids)

                td = timezone.utcnow() - ts
                stats.append(FileLoadStat(
                    filepath.replace(settings.DAGS_FOLDER, ''),
                    td,
                    len(found_dags),
                    sum([len(dag.tasks) for dag in found_dags]),
                    dag_id_names,
                ))
            except Exception as e:
                self.log.exception(e)
        Stats.gauge(
            'collect_dags', (timezone.utcnow() - start_dttm).total_seconds(), 1)
        Stats.gauge('dagbag_size', len(self.dags), 1)
        Stats.gauge('dagbag_import_errors', len(self.import_errors), 1)
        self.dagbag_stats = sorted(
            stats, key=lambda x: x.duration, reverse=True)
        for file_stat in self.dagbag_stats:
            # file_stat.file is in a format like: /subdir/dag_name.py
            # TODO: Remove for Airflow 2.0
            filename = file_stat.file.split('/')[-1].replace('.py', '')
            Stats.timing('dag.loading-duration.{}'.format(filename),
                         file_stat.duration)

    def collect_dags_from_db(self):
        """Collects DAGs from database."""
        from airflow.models.serialized_dag import SerializedDagModel
        start_dttm = timezone.utcnow()
        self.log.info("Filling up the DagBag from database")

        # The dagbag contains all rows in serialized_dag table. Deleted DAGs are deleted
        # from the table by the scheduler job.
        self.dags = SerializedDagModel.read_all_dags()

        # Adds subdags.
        # DAG post-processing steps such as self.bag_dag and croniter are not needed as
        # they are done by scheduler before serialization.
        subdags = {}
        for dag in self.dags.values():
            for subdag in dag.subdags:
                subdags[subdag.dag_id] = subdag
        self.dags.update(subdags)

        Stats.timing('collect_db_dags', timezone.utcnow() - start_dttm)

    def dagbag_report(self):
        """Prints a report around DagBag loading stats"""
        report = textwrap.dedent("""\n
        -------------------------------------------------------------------
        DagBag loading stats for {dag_folder}
        -------------------------------------------------------------------
        Number of DAGs: {dag_num}
        Total task number: {task_num}
        DagBag parsing time: {duration}
        {table}
        """)
        stats = self.dagbag_stats
        return report.format(
            dag_folder=self.dag_folder,
            duration=sum([o.duration for o in stats], timedelta()).total_seconds(),
            dag_num=sum([o.dag_num for o in stats]),
            task_num=sum([o.task_num for o in stats]),
            table=pprinttable(stats),
        )
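
A short usage sketch of the DagBag API shown above (the DAG folder path is a placeholder): fill a bag from a folder, check for import errors, and print the loading report.

from airflow.models import DagBag

dagbag = DagBag(dag_folder="/path/to/dags", include_examples=False)
print("Loaded {} DAGs: {}".format(dagbag.size(), list(dagbag.dag_ids)))
# Files that failed to import end up in import_errors instead of raising.
for filepath, error in dagbag.import_errors.items():
    print("Import error in {}: {}".format(filepath, error))
print(dagbag.dagbag_report())
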
Example #22
0
def apply_caching(response):
    _x_frame_enabled = conf.getboolean('webserver', 'X_FRAME_ENABLED', fallback=True)
    if not _x_frame_enabled:
        response.headers["X-Frame-Options"] = "DENY"
    return response
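
apply_caching only takes effect once it is registered as an after-request hook. A minimal sketch of wiring it up explicitly, assuming the function above is in scope; this is equivalent to the @app.after_request decorator form used in the later create_app examples.

from flask import Flask

app = Flask(__name__)
# Every outgoing response now passes through apply_caching.
app.after_request(apply_caching)
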
Example #23
0
def create_app(config=None, testing=False):
    """Create a new instance of Airflow WWW app"""
    flask_app = Flask(__name__)
    flask_app.secret_key = conf.get('webserver', 'SECRET_KEY')

    flask_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        minutes=settings.get_session_lifetime_config())
    flask_app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    flask_app.config['APP_NAME'] = conf.get(section="webserver",
                                            key="instance_name",
                                            fallback="Airflow")
    flask_app.config['TESTING'] = testing
    flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get(
        'core', 'SQL_ALCHEMY_CONN')
    flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    flask_app.config['SESSION_COOKIE_HTTPONLY'] = True
    flask_app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')

    cookie_samesite_config = conf.get('webserver', 'COOKIE_SAMESITE')
    if cookie_samesite_config == "":
        warnings.warn(
            "Old deprecated value found for `cookie_samesite` option in `[webserver]` section. "
            "Using `Lax` instead. Change the value to `Lax` in airflow.cfg to remove this warning.",
            DeprecationWarning,
        )
        cookie_samesite_config = "Lax"
    flask_app.config['SESSION_COOKIE_SAMESITE'] = cookie_samesite_config

    if config:
        flask_app.config.from_mapping(config)

    if 'SQLALCHEMY_ENGINE_OPTIONS' not in flask_app.config:
        flask_app.config[
            'SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()

    # Configure the JSON encoder used by `|tojson` filter from Flask
    flask_app.json_encoder = AirflowJsonEncoder

    csrf.init_app(flask_app)

    init_wsgi_middleware(flask_app)

    db = SQLA()
    db.session = settings.Session
    db.init_app(flask_app)

    init_dagbag(flask_app)

    init_api_experimental_auth(flask_app)

    init_robots(flask_app)

    cache_config = {
        'CACHE_TYPE': 'flask_caching.backends.filesystem',
        'CACHE_DIR': gettempdir()
    }
    Cache(app=flask_app, config=cache_config)

    init_flash_views(flask_app)

    configure_logging()
    configure_manifest_files(flask_app)

    with flask_app.app_context():
        init_appbuilder(flask_app)

        init_appbuilder_views(flask_app)
        init_appbuilder_links(flask_app)
        init_plugins(flask_app)
        init_connection_form()
        init_error_handlers(flask_app)
        init_api_connexion(flask_app)
        init_api_experimental(flask_app)

        sync_appbuilder_roles(flask_app)

        init_jinja_globals(flask_app)
        init_xframe_protection(flask_app)
        init_permanent_session(flask_app)
        init_airflow_session_interface(flask_app)
    return flask_app
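
A quick smoke-test sketch for the factory above, using Flask's test client; the /health endpoint is an assumption about this Airflow version rather than something shown in the excerpt.

app = create_app(testing=True)
with app.test_client() as client:
    response = client.get("/health")  # endpoint path assumed
    print(response.status_code, response.data[:200])
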
Example #24
0
def create_app(config=None, testing=False, app_name="Airflow"):
    global app, appbuilder
    app = Flask(__name__)
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    session_lifetime_days = conf.getint('webserver', 'SESSION_LIFETIME_DAYS', fallback=30)
    app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=session_lifetime_days)

    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config['APP_NAME'] = app_name
    app.config['TESTING'] = testing
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    app.config['SESSION_COOKIE_HTTPONLY'] = True
    app.config['SESSION_COOKIE_SECURE'] = conf.getboolean('webserver', 'COOKIE_SECURE')
    app.config['SESSION_COOKIE_SAMESITE'] = conf.get('webserver', 'COOKIE_SAMESITE')

    if config:
        app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    app.json_encoder = AirflowJsonEncoder

    csrf.init_app(app)
    db = SQLA()
    db.session = settings.Session
    db.init_app(app)

    from airflow import api
    api.load_auth()
    api.API_AUTH.api_auth.init_app(app)

    Cache(app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})

    from airflow.www.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():
        from airflow.www.security import AirflowSecurityManager
        security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \
            AirflowSecurityManager

        if not issubclass(security_manager_class, AirflowSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager,
                 not FAB's security manager.""")

        class AirflowAppBuilder(AppBuilder):

            def _check_and_init(self, baseview):
                if hasattr(baseview, 'datamodel'):
                    # Delete sessions if initiated previously to limit side effects. We want to use
                    # the current session in the current application.
                    baseview.datamodel.session = None
                return super()._check_and_init(baseview)

        appbuilder = AirflowAppBuilder(
            app=app,
            session=settings.Session,
            security_manager_class=security_manager_class,
            base_template='airflow/master.html',
            update_perms=conf.getboolean('webserver', 'UPDATE_FAB_PERMS'))

        def init_views(appbuilder):
            from airflow.www import views
            # Remove the session from scoped_session registry to avoid
            # reusing a session with a disconnected connection
            appbuilder.session.remove()
            appbuilder.add_view_no_menu(views.Airflow())
            appbuilder.add_view_no_menu(views.DagModelView())
            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            appbuilder.add_view(views.JobModelView,
                                "Jobs",
                                category="Browse")
            appbuilder.add_view(views.LogModelView,
                                "Logs",
                                category="Browse")
            appbuilder.add_view(views.SlaMissModelView,
                                "SLA Misses",
                                category="Browse")
            appbuilder.add_view(views.TaskInstanceModelView,
                                "Task Instances",
                                category="Browse")
            appbuilder.add_view(views.ConfigurationView,
                                "Configurations",
                                category="Admin",
                                category_icon="fa-user")
            appbuilder.add_view(views.ConnectionModelView,
                                "Connections",
                                category="Admin")
            appbuilder.add_view(views.PoolModelView,
                                "Pools",
                                category="Admin")
            appbuilder.add_view(views.VariableModelView,
                                "Variables",
                                category="Admin")
            appbuilder.add_view(views.XComModelView,
                                "XComs",
                                category="Admin")

            if "dev" in version.version:
                airflow_doc_site = "https://airflow.readthedocs.io/en/latest"
            else:
                airflow_doc_site = 'https://airflow.apache.org/docs/{}'.format(version.version)

            appbuilder.add_link("Website",
                                href='https://airflow.apache.org',
                                category="Docs",
                                category_icon="fa-globe")
            appbuilder.add_link("Documentation",
                                href=airflow_doc_site,
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_view(views.VersionView,
                                'Version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow import plugins_manager

                plugins_manager.initialize_web_ui_plugins()

                for v in plugins_manager.flask_appbuilder_views:
                    log.debug("Adding view %s", v["name"])
                    appbuilder.add_view(v["view"],
                                        v["name"],
                                        category=v["category"])
                for ml in sorted(plugins_manager.flask_appbuilder_menu_links, key=lambda x: x["name"]):
                    log.debug("Adding menu link %s", ml["name"])
                    appbuilder.add_link(ml["name"],
                                        href=ml["href"],
                                        category=ml["category"],
                                        category_icon=ml["category_icon"])

            integrate_plugins()
            # Garbage collect old permissions/views after they have been modified.
            # Otherwise, when the name of a view or menu is changed, the framework
            # will add the new Views and Menus names to the backend, but will not
            # delete the old ones.

        def init_plugin_blueprints(app):
            from airflow.plugins_manager import flask_blueprints

            for bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s", bp["name"], bp["blueprint"].import_name)
                app.register_blueprint(bp["blueprint"])

        def init_error_handlers(app: Flask):
            from airflow.www import views
            app.register_error_handler(500, views.show_traceback)
            app.register_error_handler(404, views.circles)

        init_views(appbuilder)
        init_plugin_blueprints(app)
        init_error_handlers(app)

        if conf.getboolean('webserver', 'UPDATE_FAB_PERMS'):
            security_manager = appbuilder.sm
            security_manager.sync_roles()

        from airflow.www.api.experimental import endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            import importlib
            importlib.reload(e)

        app.register_blueprint(e.api_experimental, url_prefix='/api/experimental')

        server_timezone = conf.get('core', 'default_timezone')
        if server_timezone == "system":
            server_timezone = pendulum.local_timezone().name
        elif server_timezone == "utc":
            server_timezone = "UTC"

        default_ui_timezone = conf.get('webserver', 'default_ui_timezone')
        if default_ui_timezone == "system":
            default_ui_timezone = pendulum.local_timezone().name
        elif default_ui_timezone == "utc":
            default_ui_timezone = "UTC"
        if not default_ui_timezone:
            default_ui_timezone = server_timezone

        @app.context_processor
        def jinja_globals():  # pylint: disable=unused-variable

            globals = {
                'server_timezone': server_timezone,
                'default_ui_timezone': default_ui_timezone,
                'hostname': socket.getfqdn() if conf.getboolean(
                    'webserver', 'EXPOSE_HOSTNAME', fallback=True) else 'redact',
                'navbar_color': conf.get(
                    'webserver', 'NAVBAR_COLOR'),
                'log_fetch_delay_sec': conf.getint(
                    'webserver', 'log_fetch_delay_sec', fallback=2),
                'log_auto_tailing_offset': conf.getint(
                    'webserver', 'log_auto_tailing_offset', fallback=30),
                'log_animation_speed': conf.getint(
                    'webserver', 'log_animation_speed', fallback=1000)
            }

            if 'analytics_tool' in conf.getsection('webserver'):
                globals.update({
                    'analytics_tool': conf.get('webserver', 'ANALYTICS_TOOL'),
                    'analytics_id': conf.get('webserver', 'ANALYTICS_ID')
                })

            return globals

        @app.before_request
        def before_request():
            _force_log_out_after = conf.getint('webserver', 'FORCE_LOG_OUT_AFTER', fallback=0)
            if _force_log_out_after > 0:
                flask.session.permanent = True
                app.permanent_session_lifetime = datetime.timedelta(minutes=_force_log_out_after)
                flask.session.modified = True
                flask.g.user = flask_login.current_user

        @app.after_request
        def apply_caching(response):
            _x_frame_enabled = conf.getboolean('webserver', 'X_FRAME_ENABLED', fallback=True)
            if not _x_frame_enabled:
                response.headers["X-Frame-Options"] = "DENY"
            return response

        @app.before_request
        def make_session_permanent():
            flask_session.permanent = True

    return app, appbuilder
Example #25
0
from airflow.configuration import conf, AirflowConfigException
from airflow.models import DAG
from flask.ext.admin import BaseView
from importlib import import_module
from airflow.utils import AirflowException

DAGS_FOLDER = os.path.expanduser(conf.get("core", "DAGS_FOLDER"))
if DAGS_FOLDER not in sys.path:
    sys.path.append(DAGS_FOLDER)

auth_backend = "airflow.default_login"
try:
    auth_backend = conf.get("webserver", "auth_backend")
except AirflowConfigException:
    if conf.getboolean("webserver", "AUTHENTICATE"):
        logging.warning(
            "auth_backend not found in webserver config reverting to *deprecated*"
            " behavior of importing airflow_login"
        )
        auth_backend = "airflow_login"

try:
    login = import_module(auth_backend)
except ImportError:
    logging.critical(
        "Cannot import authentication module %s. "
        "Please correct your authentication backend or disable authentication",
        auth_backend,
    )
    if conf.getboolean("webserver", "AUTHENTICATE"):
Example #26
0
def init_jinja_globals(app):
    """Add extra global variables to the Jinja context"""
    server_timezone = conf.get('core', 'default_timezone')
    if server_timezone == "system":
        server_timezone = pendulum.local_timezone().name
    elif server_timezone == "utc":
        server_timezone = "UTC"

    default_ui_timezone = conf.get('webserver', 'default_ui_timezone')
    if default_ui_timezone == "system":
        default_ui_timezone = pendulum.local_timezone().name
    elif default_ui_timezone == "utc":
        default_ui_timezone = "UTC"
    if not default_ui_timezone:
        default_ui_timezone = server_timezone

    expose_hostname = conf.getboolean('webserver',
                                      'EXPOSE_HOSTNAME',
                                      fallback=True)
    hostname = socket.getfqdn() if expose_hostname else 'redact'

    try:
        airflow_version = airflow.__version__
    except Exception as e:  # pylint: disable=broad-except
        airflow_version = None
        logging.error(e)

    git_version = get_airflow_git_version()

    def prepare_jinja_globals():
        extra_globals = {
            'server_timezone': server_timezone,
            'default_ui_timezone': default_ui_timezone,
            'hostname': hostname,
            'navbar_color': conf.get('webserver', 'NAVBAR_COLOR'),
            'log_fetch_delay_sec': conf.getint('webserver', 'log_fetch_delay_sec', fallback=2),
            'log_auto_tailing_offset': conf.getint('webserver', 'log_auto_tailing_offset', fallback=30),
            'log_animation_speed': conf.getint('webserver', 'log_animation_speed', fallback=1000),
            'state_color_mapping': STATE_COLORS,
            'airflow_version': airflow_version,
            'git_version': git_version,
            'k8s_or_k8scelery_executor': IS_K8S_OR_K8SCELERY_EXECUTOR,
        }

        if 'analytics_tool' in conf.getsection('webserver'):
            extra_globals.update({
                'analytics_tool': conf.get('webserver', 'ANALYTICS_TOOL'),
                'analytics_id': conf.get('webserver', 'ANALYTICS_ID'),
            })

        return extra_globals

    app.context_processor(prepare_jinja_globals)
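
A sketch of exercising init_jinja_globals on a bare Flask app, assuming the function above is in scope; the context processor injects the keys from extra_globals whenever a template is rendered.

from flask import Flask, render_template_string

app = Flask(__name__)
init_jinja_globals(app)

with app.test_request_context():
    # 'airflow_version' and 'hostname' come from the dictionary built above.
    print(render_template_string("{{ airflow_version }} on {{ hostname }}"))
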
Example #27
0
    import_local_settings()
    global LOGGING_CLASS_PATH
    LOGGING_CLASS_PATH = configure_logging()
    configure_adapters()
    # The webservers import this file from models.py with the default settings.
    configure_orm()
    configure_action_logging()

    # Ensure we close DB connections at scheduler and gunicorn worker termination
    atexit.register(dispose_orm)


# Const stuff

KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0',
              'LIGHTORANGE': '#FF9933'}

# Used by DAG context_managers
CONTEXT_MANAGER_DAG = None  # type: Optional[airflow.models.dag.DAG]

# If store_serialized_dags is True, scheduler writes serialized DAGs to DB, and webserver
# reads DAGs from DB instead of importing from files.
STORE_SERIALIZED_DAGS = conf.getboolean('core', 'store_serialized_dags', fallback=False)

# Updating a serialized DAG can not happen more often than a minimum interval,
# to reduce the database write rate.
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
    'core', 'min_serialized_dag_update_interval', fallback=30)
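
These two settings pair with the DagBag example earlier: when STORE_SERIALIZED_DAGS is enabled, a webserver-side bag can be filled from the database instead of parsing files. A sketch, assuming the constants live in airflow.settings as the snippet suggests:

from airflow import settings
from airflow.models import DagBag

dagbag = DagBag(store_serialized_dags=settings.STORE_SERIALIZED_DAGS)
if settings.STORE_SERIALIZED_DAGS:
    # Reads all rows of the serialized_dag table, as in collect_dags_from_db() above.
    dagbag.collect_dags_from_db()
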
Example #28
0
import pendulum
import sys
from typing import Any

from sqlalchemy import create_engine, exc
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool

from airflow.configuration import conf, AIRFLOW_HOME, WEBSERVER_CONFIG  # NOQA F401
from airflow.contrib.kubernetes.pod import Pod
from airflow.logging_config import configure_logging
from airflow.utils.sqlalchemy import setup_event_handlers

log = logging.getLogger(__name__)

RBAC = conf.getboolean('webserver', 'rbac')

TIMEZONE = pendulum.timezone('UTC')
try:
    tz = conf.get("core", "default_timezone")
    if tz == "system":
        TIMEZONE = pendulum.local_timezone()
    else:
        TIMEZONE = pendulum.timezone(tz)
except Exception:
    pass
log.info("Configured default timezone %s" % TIMEZONE)


class DummyStatsLogger(object):
    @classmethod
Example #29
0

def nulls_first(col, session: Session) -> Dict[str, Any]:
    """
    Adds a nullsfirst construct to the column ordering. Currently only Postgres
    supports it. In MySQL and SQLite, NULL values are considered lower than any
    non-NULL value, so they already appear first when the order is ascending (ASC).
    """
    if session.bind.dialect.name == "postgresql":
        return nullsfirst(col)
    else:
        return col
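
A usage sketch for nulls_first as defined above; the model and column are illustrative, and the ORM session is taken from airflow.settings.

from airflow import settings
from airflow.models import TaskInstance

session = settings.Session()
# On Postgres, NULL priority weights are pulled to the front; on other
# backends the expression is returned unchanged.
ordering = nulls_first(TaskInstance.priority_weight.desc(), session)
first_ten = session.query(TaskInstance).order_by(ordering).limit(10).all()
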


USE_ROW_LEVEL_LOCKING: bool = conf.getboolean('scheduler',
                                              'use_row_level_locking',
                                              fallback=True)


def with_row_locks(query, session: Session, **kwargs):
    """
    Apply with_for_update to an SQLAlchemy query, if row level locking is in use.

    :param query: An SQLAlchemy Query object
    :param session: ORM Session
    :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
    :return: updated query
    """
    dialect = session.bind.dialect

    # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
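
The body above is cut off, but the docstring already describes the contract. A usage sketch (the query is illustrative) that only takes the lock when the backend supports it:

from airflow import settings
from airflow.models import DagRun

session = settings.Session()
query = session.query(DagRun).filter(DagRun.state == "running")
# skip_locked is forwarded to with_for_update when row level locking is in use.
running_runs = with_row_locks(query, session, skip_locked=True).all()
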
Example #30
0
        broker_transport_options['visibility_timeout'] = 21600

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'),
    'task_acks_late': True,
    'task_default_queue': conf.get('operators', 'DEFAULT_QUEUE'),
    'task_default_exchange': conf.get('operators', 'DEFAULT_QUEUE'),
    'task_track_started': conf.getboolean('celery', 'task_track_started'),
    'broker_url': broker_url,
    'broker_transport_options': broker_transport_options,
    'result_backend': conf.get('celery', 'RESULT_BACKEND'),
    'worker_concurrency': conf.getint('celery', 'WORKER_CONCURRENCY'),
    'worker_enable_remote_control': conf.getboolean('celery', 'worker_enable_remote_control'),
}

celery_ssl_active = False
try:
    celery_ssl_active = conf.getboolean('celery', 'SSL_ACTIVE')
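
DEFAULT_CELERY_CONFIG is meant to be handed to a Celery app. A minimal sketch; the app name is illustrative, and broker_url / broker_transport_options are assumed to be defined above the excerpt.

from celery import Celery

celery_app = Celery("airflow.executors.celery_executor")
# conf.update accepts a plain mapping, so the dictionary above applies as-is.
celery_app.conf.update(DEFAULT_CELERY_CONFIG)
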
Example #31
0
def create_app(config=None, session=None, testing=False, app_name="Airflow"):
    global app, appbuilder
    app = Flask(__name__)
    if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
        app.wsgi_app = ProxyFix(app.wsgi_app,
                                num_proxies=None,
                                x_for=1,
                                x_proto=1,
                                x_host=1,
                                x_port=1,
                                x_prefix=1)
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    app.config['APP_NAME'] = app_name
    app.config['TESTING'] = testing

    app.config['SESSION_COOKIE_HTTPONLY'] = True
    app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    app.config['SESSION_COOKIE_SAMESITE'] = conf.get('webserver',
                                                     'COOKIE_SAMESITE')

    if config:
        app.config.from_mapping(config)

    csrf.init_app(app)

    db = SQLA(app)

    from airflow import api
    api.load_auth()
    api.API_AUTH.api_auth.init_app(app)

    # flake8: noqa: F841
    cache = Cache(app=app,
                  config={
                      'CACHE_TYPE': 'filesystem',
                      'CACHE_DIR': '/tmp'
                  })

    from airflow.www_rbac.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():

        from airflow.www_rbac.security import AirflowSecurityManager
        security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \
            AirflowSecurityManager

        if not issubclass(security_manager_class, AirflowSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager,
                 not FAB's security manager.""")

        appbuilder = AppBuilder(app,
                                db.session if not session else session,
                                security_manager_class=security_manager_class,
                                base_template='appbuilder/baselayout.html')

        def init_views(appbuilder):
            from airflow.www_rbac import views
            appbuilder.add_view_no_menu(views.Airflow())
            appbuilder.add_view_no_menu(views.DagModelView())
            appbuilder.add_view_no_menu(views.ConfigurationView())
            appbuilder.add_view_no_menu(views.VersionView())
            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            appbuilder.add_view(views.JobModelView, "Jobs", category="Browse")
            appbuilder.add_view(views.LogModelView, "Logs", category="Browse")
            appbuilder.add_view(views.SlaMissModelView,
                                "SLA Misses",
                                category="Browse")
            appbuilder.add_view(views.TaskInstanceModelView,
                                "Task Instances",
                                category="Browse")
            appbuilder.add_link("Configurations",
                                href='/configuration',
                                category="Admin",
                                category_icon="fa-user")
            appbuilder.add_view(views.ConnectionModelView,
                                "Connections",
                                category="Admin")
            appbuilder.add_view(views.PoolModelView, "Pools", category="Admin")
            appbuilder.add_view(views.VariableModelView,
                                "Variables",
                                category="Admin")
            appbuilder.add_view(views.XComModelView, "XComs", category="Admin")
            appbuilder.add_link("Documentation",
                                href='https://airflow.apache.org/',
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_link('Version',
                                href='/version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow.plugins_manager import (
                    flask_appbuilder_views, flask_appbuilder_menu_links)

                for v in flask_appbuilder_views:
                    log.debug("Adding view %s", v["name"])
                    appbuilder.add_view(v["view"],
                                        v["name"],
                                        category=v["category"])
                for ml in sorted(flask_appbuilder_menu_links,
                                 key=lambda x: x["name"]):
                    log.debug("Adding menu link %s", ml["name"])
                    appbuilder.add_link(ml["name"],
                                        href=ml["href"],
                                        category=ml["category"],
                                        category_icon=ml["category_icon"])

            integrate_plugins()
            # Garbage collect old permissions/views after they have been modified.
            # Otherwise, when the name of a view or menu is changed, the framework
            # will add the new Views and Menus names to the backend, but will not
            # delete the old ones.

        def init_plugin_blueprints(app):
            from airflow.plugins_manager import flask_blueprints

            for bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s", bp["name"],
                          bp["blueprint"].import_name)
                app.register_blueprint(bp["blueprint"])

        init_views(appbuilder)
        init_plugin_blueprints(app)

        security_manager = appbuilder.sm
        security_manager.sync_roles()

        from airflow.www_rbac.api.experimental import endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            if six.PY2:
                reload(e)  # noqa
            else:
                import importlib
                importlib.reload(e)

        app.register_blueprint(e.api_experimental,
                               url_prefix='/api/experimental')

        @app.context_processor
        def jinja_globals():  # pylint: disable=unused-variable

            globals = {
                'hostname': socket.getfqdn(),
                'navbar_color': conf.get('webserver', 'NAVBAR_COLOR'),
            }

            if 'analytics_tool' in conf.getsection('webserver'):
                globals.update({
                    'analytics_tool': conf.get('webserver', 'ANALYTICS_TOOL'),
                    'analytics_id': conf.get('webserver', 'ANALYTICS_ID')
                })

            return globals

        @app.teardown_appcontext
        def shutdown_session(exception=None):
            settings.Session.remove()

    return app, appbuilder
Example #32
0
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.kube_env_vars = configuration_dict.get('kubernetes_environment_variables', {})
        self.airflow_home = configuration.get(self.core_section, 'airflow_home')
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section, 'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(
            self.kubernetes_section, 'worker_container_tag')
        self.kube_image = '{}:{}'.format(
            self.worker_container_repository, self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy"
        )
        self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
        self.kube_annotations = configuration_dict.get('kubernetes_annotations', {})
        self.delete_worker_pods = conf.getboolean(
            self.kubernetes_section, 'delete_worker_pods')
        self.worker_pods_creation_batch_size = conf.getint(
            self.kubernetes_section, 'worker_pods_creation_batch_size')
        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section, 'image_pull_secrets')

        # NOTE: user can build the dags into the docker image directly,
        # this will set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section, 'dags_in_image')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')
        # Optionally, the root directory for git operations
        self.git_sync_root = conf.get(self.kubernetes_section, 'git_sync_root')
        # Optionally, the name at which to publish the checked-out files under --root
        self.git_sync_dest = conf.get(self.kubernetes_section, 'git_sync_dest')
        # Optionally, if git_dags_folder_mount_point is set the worker will use
        # {git_dags_folder_mount_point}/{git_sync_dest}/{git_subpath} as dags_folder
        self.git_dags_folder_mount_point = conf.get(self.kubernetes_section,
                                                    'git_dags_folder_mount_point')

        # Optionally a user may supply a (`git_user` AND `git_password`) OR
        # (`git_ssh_key_secret_name` AND `git_ssh_key_secret_key`) for private repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')
        self.git_ssh_key_secret_name = conf.get(self.kubernetes_section, 'git_ssh_key_secret_name')
        self.git_ssh_known_hosts_configmap_name = conf.get(self.kubernetes_section,
                                                           'git_ssh_known_hosts_configmap_name')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section, 'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section, 'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(
            self.kubernetes_section, 'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(
            self.kubernetes_section, 'logs_volume_subpath')

        # Optionally, hostPath volume containing DAGs
        self.dags_volume_host = conf.get(self.kubernetes_section, 'dags_volume_host')

        # Optionally, write logs to a hostPath Volume
        self.logs_volume_host = conf.get(self.kubernetes_section, 'logs_volume_host')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section, 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside.
        # Note that if your cluster has RBAC enabled, your scheduler may need
        # service account permissions to create, watch, get, and delete pods in
        # this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor.
        # Note that if your cluster has RBAC enabled, your workers may need
        # service account permissions to interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section, 'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(
            self.kubernetes_section, 'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section, 'airflow_configmap')

        affinity_json = conf.get(self.kubernetes_section, 'affinity')
        if affinity_json:
            self.kube_affinity = json.loads(affinity_json)
        else:
            self.kube_affinity = None

        tolerations_json = conf.get(self.kubernetes_section, 'tolerations')
        if tolerations_json:
            self.kube_tolerations = json.loads(tolerations_json)
        else:
            self.kube_tolerations = None

        self._validate()
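
A sketch of reading a few of the derived values after construction; the class name is not shown in the excerpt, so KubeConfig is assumed here (the usual name for this object in the KubernetesExecutor).

kube_config = KubeConfig()  # hypothetical class name for the __init__ shown above
print("Worker image:", kube_config.kube_image)
print("DAGs baked into image:", kube_config.dags_in_image)
print("Executor namespace:", kube_config.executor_namespace)
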
Example #33
0
from airflow import jobs
from airflow import models
from airflow.models import State
from airflow import settings
from airflow.configuration import conf
from airflow import utils
from airflow.www import utils as wwwutils

from airflow.www.login import login_manager
import flask_login
from flask_login import login_required

QUERY_LIMIT = 100000
CHART_LIMIT = 200000

AUTHENTICATE = conf.getboolean('core', 'AUTHENTICATE')
if AUTHENTICATE is False:
    login_required = lambda x: x

dagbag = models.DagBag(conf.get('core', 'DAGS_FOLDER'))
utils.pessimistic_connection_handling()

app = Flask(__name__)
app.config['SQLALCHEMY_POOL_RECYCLE'] = 3600

login_manager.init_app(app)
app.secret_key = 'airflowified'

cache = Cache(
    app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})
Example #34
0
__version__ = "1.3.0"

import logging
import os
import sys
from airflow.configuration import conf
from airflow.models import DAG
from flask.ext.admin import BaseView


DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))
if DAGS_FOLDER not in sys.path:
    sys.path.append(DAGS_FOLDER)

from airflow import default_login as login
if conf.getboolean('webserver', 'AUTHENTICATE'):
    try:
        # Environment specific login
        import airflow_login as login
    except ImportError:
        logging.error(
            "authenticate is set to True in airflow.cfg, "
            "but airflow_login failed to import")


class AirflowViewPlugin(BaseView):
    pass

class AirflowMacroPlugin(object):
    def __init__(self, namespace):
        self.namespace = namespace
Example #35
0
from future import standard_library
standard_library.install_aliases()
from builtins import str
from builtins import object
from cgi import escape
from io import BytesIO as IO
import gzip
import functools

from flask import after_this_request, request
from flask_login import current_user
import wtforms
from wtforms.compat import text_type

from airflow.configuration import conf
AUTHENTICATE = conf.getboolean('webserver', 'AUTHENTICATE')


class LoginMixin(object):
    def is_accessible(self):
        return (not AUTHENTICATE or (not current_user.is_anonymous()
                                     and current_user.is_authenticated()))


class SuperUserMixin(object):
    def is_accessible(self):
        return (not AUTHENTICATE or (not current_user.is_anonymous()
                                     and current_user.is_superuser()))


class DataProfilingMixin(object):
Example #36
0
# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL: str = conf.get('logging', 'LOGGING_LEVEL').upper()

# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
FAB_LOG_LEVEL: str = conf.get('logging', 'FAB_LOGGING_LEVEL').upper()

LOG_FORMAT: str = conf.get('logging', 'LOG_FORMAT')

COLORED_LOG_FORMAT: str = conf.get('logging', 'COLORED_LOG_FORMAT')

COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG')

COLORED_FORMATTER_CLASS: str = conf.get('logging', 'COLORED_FORMATTER_CLASS')

BASE_LOG_FOLDER: str = conf.get('logging', 'BASE_LOG_FOLDER')

PROCESSOR_LOG_FOLDER: str = conf.get('scheduler',
                                     'CHILD_PROCESS_LOG_DIRECTORY')

DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get(
    'logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')

FILENAME_TEMPLATE: str = conf.get('logging', 'LOG_FILENAME_TEMPLATE')

PROCESSOR_FILENAME_TEMPLATE: str = conf.get('logging',
                                            'LOG_PROCESSOR_FILENAME_TEMPLATE')
Beispiel #37
0
    def start(self):
        self.task_queue = Queue()
        self.result_queue = Queue()
        framework = mesos_pb2.FrameworkInfo()
        framework.user = ""

        if not conf.get("mesos", "MASTER"):
            logging.error("Expecting mesos master URL for mesos executor")
            raise AirflowException("mesos.master not provided for mesos executor")

        master = conf.get("mesos", "MASTER")

        if not conf.get("mesos", "FRAMEWORK_NAME"):
            framework.name = "Airflow"
        else:
            framework.name = conf.get("mesos", "FRAMEWORK_NAME")

        if not conf.get("mesos", "TASK_CPU"):
            task_cpu = 1
        else:
            task_cpu = conf.getint("mesos", "TASK_CPU")

        if not conf.get("mesos", "TASK_MEMORY"):
            task_memory = 256
        else:
            task_memory = conf.getint("mesos", "TASK_MEMORY")

        if conf.getboolean("mesos", "CHECKPOINT"):
            framework.checkpoint = True
        else:
            framework.checkpoint = False

        logging.info(
            "MesosFramework master : %s, name : %s, cpu : %s, mem : %s, checkpoint : %s",
            master,
            framework.name,
            str(task_cpu),
            str(task_memory),
            str(framework.checkpoint),
        )

        implicit_acknowledgements = 1

        if conf.getboolean("mesos", "AUTHENTICATE"):
            if not conf.get("mesos", "DEFAULT_PRINCIPAL"):
                logging.error("Expecting authentication principal in the environment")
                raise AirflowException("mesos.default_principal not provided in authenticated mode")
            if not conf.get("mesos", "DEFAULT_SECRET"):
                logging.error("Expecting authentication secret in the environment")
                raise AirflowException("mesos.default_secret not provided in authenticated mode")

            credential = mesos_pb2.Credential()
            credential.principal = conf.get("mesos", "DEFAULT_PRINCIPAL")
            credential.secret = conf.get("mesos", "DEFAULT_SECRET")

            framework.principal = credential.principal

            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue, self.result_queue, task_cpu, task_memory),
                framework,
                master,
                implicit_acknowledgements,
                credential,
            )
        else:
            framework.principal = "Airflow"
            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue, self.result_queue, task_cpu, task_memory),
                framework,
                master,
                implicit_acknowledgements,
            )

        self.mesos_driver = driver
        self.mesos_driver.start()
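The start() method above is driven entirely by the [mesos] section of airflow.cfg. A minimal sketch of those options, set programmatically here with made-up values for illustration only.

from airflow.configuration import conf

if not conf.has_section("mesos"):
    conf.add_section("mesos")
conf.set("mesos", "MASTER", "10.0.0.1:5050")    # required, otherwise AirflowException
conf.set("mesos", "FRAMEWORK_NAME", "Airflow")  # falls back to "Airflow" when unset
conf.set("mesos", "TASK_CPU", "1")              # falls back to 1
conf.set("mesos", "TASK_MEMORY", "256")         # falls back to 256
conf.set("mesos", "CHECKPOINT", "False")
conf.set("mesos", "AUTHENTICATE", "False")      # if True, DEFAULT_PRINCIPAL and DEFAULT_SECRET are required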
Beispiel #38
0
from airflow.configuration import conf

# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL = 'INFO'  # conf.get('core', 'LOGGING_LEVEL').upper()

# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
FAB_LOG_LEVEL = conf.get('core', 'FAB_LOGGING_LEVEL').upper()

LOG_FORMAT = conf.get('core', 'LOG_FORMAT')

COLORED_LOG_FORMAT = conf.get('core', 'COLORED_LOG_FORMAT')

COLORED_LOG = conf.getboolean('core', 'COLORED_CONSOLE_LOG')

COLORED_FORMATTER_CLASS = conf.get('core', 'COLORED_FORMATTER_CLASS')

BASE_LOG_FOLDER = '/opt/airflow/logs'
# BASE_LOG_FOLDER = conf.get('core', 'BASE_LOG_FOLDER')
#
PROCESSOR_LOG_FOLDER = '/opt/airflow/logs'
# PROCESSOR_LOG_FOLDER = conf.get('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')

DAG_PROCESSOR_MANAGER_LOG_LOCATION = \
    conf.get('core', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')

FILENAME_TEMPLATE = conf.get('core', 'LOG_FILENAME_TEMPLATE')

PROCESSOR_FILENAME_TEMPLATE = conf.get('core',
                                       'LOG_PROCESSOR_FILENAME_TEMPLATE')
Beispiel #39
0
def get_kube_client(in_cluster=conf.getboolean('kubernetes', 'in_cluster'),
                    cluster_context=None):
    return _load_kube_config(in_cluster, cluster_context)
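A small usage sketch, assuming _load_kube_config returns a kubernetes.client.CoreV1Api instance; the cluster context and namespace are placeholders.

kube_client = get_kube_client(in_cluster=False, cluster_context="docker-desktop")
pods = kube_client.list_namespaced_pod(namespace="airflow")
print([pod.metadata.name for pod in pods.items])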
Beispiel #40
0
def read_store_serialized_dags():
    from airflow.configuration import conf
    return conf.getboolean('core', 'store_serialized_dags')
Beispiel #41
0
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.airflow_home = configuration.get(self.core_section, 'airflow_home')
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section, 'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(
            self.kubernetes_section, 'worker_container_tag')
        self.worker_dags_folder = configuration.get(
            self.kubernetes_section, 'worker_dags_folder')
        self.kube_image = '{}:{}'.format(
            self.worker_container_repository, self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy"
        )
        self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
        self.delete_worker_pods = conf.getboolean(
            self.kubernetes_section, 'delete_worker_pods')

        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section, 'image_pull_secrets')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')

        # Optionally a user may supply a `git_user` and `git_password` for private
        # repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section, 'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section, 'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(
            self.kubernetes_section, 'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(
            self.kubernetes_section, 'logs_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section, 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside.
        # Note that if your cluster has RBAC enabled, your scheduler may need
        # service account permissions to create, watch, get, and delete pods in
        # this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor.
        # Note that if your cluster has RBAC enabled, your workers may need
        # service account permissions to interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section, 'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(
            self.kubernetes_section, 'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section, 'airflow_configmap')

        self._validate()
Beispiel #42
0
    def collect_dags(self,
                     dag_folder=None,
                     only_if_updated=True,
                     include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
                     include_smart_sensor=conf.getboolean(
                         'smart_sensor', 'USE_SMART_SENSOR'),
                     safe_mode=conf.getboolean('core',
                                               'DAG_DISCOVERY_SAFE_MODE')):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a ``.airflowignore`` file is found while processing
        the directory, it will behave much like a ``.gitignore``,
        ignoring files that match any of the regex patterns specified
        in the file.

        **Note**: The patterns in .airflowignore are treated as
        un-anchored regexes, not shell-like glob patterns.
        """
        if self.read_dags_from_db:
            return

        self.log.info("Filling up the DagBag from %s", dag_folder)
        start_dttm = timezone.utcnow()
        dag_folder = dag_folder or self.dag_folder
        # Used to store stats around DagBag processing
        stats = []

        dag_folder = correct_maybe_zipped(dag_folder)
        for filepath in list_py_file_paths(
                dag_folder,
                safe_mode=safe_mode,
                include_examples=include_examples,
                include_smart_sensor=include_smart_sensor):
            try:
                file_parse_start_dttm = timezone.utcnow()
                found_dags = self.process_file(filepath,
                                               only_if_updated=only_if_updated,
                                               safe_mode=safe_mode)

                file_parse_end_dttm = timezone.utcnow()
                stats.append(
                    FileLoadStat(
                        file=filepath.replace(settings.DAGS_FOLDER, ''),
                        duration=file_parse_end_dttm - file_parse_start_dttm,
                        dag_num=len(found_dags),
                        task_num=sum([len(dag.tasks) for dag in found_dags]),
                        dags=str([dag.dag_id for dag in found_dags]),
                    ))
            except Exception as e:  # pylint: disable=broad-except
                self.log.exception(e)

        end_dttm = timezone.utcnow()
        durations = (end_dttm - start_dttm).total_seconds()
        Stats.gauge('collect_dags', durations, 1)
        Stats.gauge('dagbag_size', len(self.dags), 1)
        Stats.gauge('dagbag_import_errors', len(self.import_errors), 1)
        self.dagbag_stats = sorted(stats,
                                   key=lambda x: x.duration,
                                   reverse=True)
        for file_stat in self.dagbag_stats:
            # file_stat.file similar format: /subdir/dag_name.py
            # TODO: Remove for Airflow 2.0
            filename = file_stat.file.split('/')[-1].replace('.py', '')
            Stats.timing('dag.loading-duration.{}'.format(filename),
                         file_stat.duration)
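A minimal usage sketch (the folder path is a placeholder): collect_dags normally also runs while the DagBag is constructed, and calling it again re-walks the folder and refreshes the per-file load stats.

from airflow.models import DagBag

dagbag = DagBag(dag_folder="/opt/airflow/dags", include_examples=False)
dagbag.collect_dags(dag_folder="/opt/airflow/dags", safe_mode=True)
print(f"{len(dagbag.dags)} DAGs loaded, {len(dagbag.import_errors)} import errors")
for file_stat in dagbag.dagbag_stats[:5]:
    print(file_stat.file, file_stat.duration)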
Beispiel #43
0
def renew_from_kt(principal: Optional[str],
                  keytab: str,
                  exit_on_fail: bool = True):
    """
    Renew kerberos token from keytab

    :param principal: principal
    :param keytab: keytab file
    :param exit_on_fail: if True, exit the process when the renewal fails;
        otherwise return the non-zero exit code
    :return: 0 on success, otherwise the failing command's exit code
        (only when ``exit_on_fail`` is False)
    """
    # The config is specified in seconds. But we ask for that same amount in
    # minutes to give ourselves a large renewal buffer.
    renewal_lifetime = f"{conf.getint('kerberos', 'reinit_frequency')}m"

    cmd_principal = principal or conf.get_mandatory_value(
        'kerberos', 'principal').replace("_HOST", socket.getfqdn())

    if conf.getboolean('kerberos', 'forwardable'):
        forwardable = '-f'
    else:
        forwardable = '-F'

    if conf.getboolean('kerberos', 'include_ip'):
        include_ip = '-a'
    else:
        include_ip = '-A'

    cmdv: List[str] = [
        conf.get_mandatory_value('kerberos', 'kinit_path'),
        forwardable,
        include_ip,
        "-r",
        renewal_lifetime,
        "-k",  # host ticket
        "-t",
        keytab,  # specify keytab
        "-c",
        conf.get_mandatory_value('kerberos',
                                 'ccache'),  # specify credentials cache
        cmd_principal,
    ]
    log.info("Re-initialising kerberos from keytab: %s",
             " ".join(shlex.quote(f) for f in cmdv))

    with subprocess.Popen(
            cmdv,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
            bufsize=-1,
            universal_newlines=True,
    ) as subp:
        subp.wait()
        if subp.returncode != 0:
            log.error(
                "Couldn't reinit from keytab! `kinit' exited with %s.\n%s\n%s",
                subp.returncode,
                "\n".join(subp.stdout.readlines() if subp.stdout else []),
                "\n".join(subp.stderr.readlines() if subp.stderr else []),
            )
            if exit_on_fail:
                sys.exit(subp.returncode)
            else:
                return subp.returncode

    global NEED_KRB181_WORKAROUND
    if NEED_KRB181_WORKAROUND is None:
        NEED_KRB181_WORKAROUND = detect_conf_var()
    if NEED_KRB181_WORKAROUND:
        # (From: HUE-640). Kerberos clock have seconds level granularity. Make sure we
        # renew the ticket after the initial valid time.
        time.sleep(1.5)
        ret = perform_krb181_workaround(cmd_principal)
        if exit_on_fail and ret != 0:
            sys.exit(ret)
        else:
            return ret
    return 0
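A small usage sketch; the principal and keytab path are placeholders. When principal is None, the function falls back to the kerberos.principal config value with _HOST expanded to the local FQDN.

ret = renew_from_kt(
    principal="airflow@EXAMPLE.COM",
    keytab="/etc/security/keytabs/airflow.keytab",
    exit_on_fail=False,
)
if ret != 0:
    log.warning("Kerberos ticket renewal failed with exit code %s", ret)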
Beispiel #44
0
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.airflow_home = configuration.get(self.core_section,
                                              'airflow_home')
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section,
                                                'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(self.kubernetes_section,
                                                      'worker_container_tag')
        self.kube_image = '{}:{}'.format(self.worker_container_repository,
                                         self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy")
        self.kube_node_selectors = configuration_dict.get(
            'kubernetes_node_selectors', {})
        self.kube_annotations = configuration_dict.get(
            'kubernetes_annotations', {})
        self.delete_worker_pods = conf.getboolean(self.kubernetes_section,
                                                  'delete_worker_pods')
        self.worker_pods_creation_batch_size = conf.getint(
            self.kubernetes_section, 'worker_pods_creation_batch_size')
        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section,
                                           'image_pull_secrets')

        # NOTE: user can build the dags into the docker image directly,
        # this will set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section,
                                             'dags_in_image')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')
        # Optionally, the root directory for git operations
        self.git_sync_root = conf.get(self.kubernetes_section, 'git_sync_root')
        # Optionally, the name at which to publish the checked-out files under --root
        self.git_sync_dest = conf.get(self.kubernetes_section, 'git_sync_dest')
        # Optionally, if git_dags_folder_mount_point is set the worker will use
        # {git_dags_folder_mount_point}/{git_sync_dest}/{git_subpath} as dags_folder
        self.git_dags_folder_mount_point = conf.get(
            self.kubernetes_section, 'git_dags_folder_mount_point')

        # Optionally a user may supply a (`git_user` AND `git_password`) OR
        # (`git_ssh_key_secret_name` AND `git_ssh_key_secret_key`) for private repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')
        self.git_ssh_key_secret_name = conf.get(self.kubernetes_section,
                                                'git_ssh_key_secret_name')
        self.git_ssh_known_hosts_configmap_name = conf.get(
            self.kubernetes_section, 'git_ssh_known_hosts_configmap_name')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section,
                                          'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section,
                                          'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(self.kubernetes_section,
                                            'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(self.kubernetes_section,
                                            'logs_volume_subpath')

        # Optionally, hostPath volume containing DAGs
        self.dags_volume_host = conf.get(self.kubernetes_section,
                                         'dags_volume_host')

        # Optionally, write logs to a hostPath Volume
        self.logs_volume_host = conf.get(self.kubernetes_section,
                                         'logs_volume_host')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section,
                                                 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside.
        # Note that if your cluster has RBAC enabled, your scheduler may need
        # service account permissions to create, watch, get, and delete pods in
        # this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor.
        # Note that if your cluster has RBAC enabled, your workers may need
        # service account permissions to interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section,
                                           'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(self.kubernetes_section,
                                               'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section,
                                          'airflow_configmap')

        affinity_json = conf.get(self.kubernetes_section, 'affinity')
        if affinity_json:
            self.kube_affinity = json.loads(affinity_json)
        else:
            self.kube_affinity = None

        tolerations_json = conf.get(self.kubernetes_section, 'tolerations')
        if tolerations_json:
            self.kube_tolerations = json.loads(tolerations_json)
        else:
            self.kube_tolerations = None

        self._validate()
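As the comments above spell out, git_repo and git_branch come as a pair, and private repositories need either git_user/git_password or an SSH key secret. A minimal sketch of a consistent [kubernetes] configuration, set programmatically with made-up values for illustration.

from airflow.configuration import conf

if not conf.has_section("kubernetes"):
    conf.add_section("kubernetes")
conf.set("kubernetes", "git_repo", "https://github.com/example/airflow-dags.git")
conf.set("kubernetes", "git_branch", "master")
conf.set("kubernetes", "git_subpath", "dags")
# Either user/password ...
conf.set("kubernetes", "git_user", "example-bot")
conf.set("kubernetes", "git_password", "example-token")  # placeholder, never commit real credentials
# ... or an SSH key secret instead:
# conf.set("kubernetes", "git_ssh_key_secret_name", "airflow-git-ssh-key")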