Example #1
def configure_orm(disable_connection_pool=False):
    log.debug("Setting up DB connection pool (PID %s)" % os.getpid())
    global engine
    global Session
    engine_args = {}

    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.configure_orm(): Using NullPool")
    elif 'sqlite' not in SQL_ALCHEMY_CONN:
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        try:
            pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE')
        except conf.AirflowConfigException:
            pool_size = 5

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        try:
            pool_recycle = conf.getint('core', 'SQL_ALCHEMY_POOL_RECYCLE')
        except conf.AirflowConfigException:
            pool_recycle = 1800

        log.info("settings.configure_orm(): Using pool settings. pool_size={}, "
                 "pool_recycle={}, pid={}".format(pool_size, pool_recycle, os.getpid()))
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use
    # us.
    engine_args['encoding'] = conf.get('core', 'SQL_ENGINE_ENCODING', fallback='utf-8')
    # For Python2 we get back a newstr and need a str
    engine_args['encoding'] = engine_args['encoding'].__str__()

    engine = create_engine(SQL_ALCHEMY_CONN, **engine_args)
    reconnect_timeout = conf.getint('core', 'SQL_ALCHEMY_RECONNECT_TIMEOUT')
    setup_event_handlers(engine, reconnect_timeout)

    Session = scoped_session(
        sessionmaker(autocommit=False,
                     autoflush=False,
                     bind=engine,
                     expire_on_commit=False))
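For context, here is a minimal, hypothetical caller-side sketch of the scoped_session factory configured above (downstream Airflow code typically reaches it as settings.Session); the query line is illustrative only.

# Hypothetical usage of the Session factory built by configure_orm().
session = Session()               # thread-local session from the scoped_session registry
try:
    # session.query(<SomeModel>).count()   # any ORM access here is illustrative only
    session.commit()
except Exception:
    session.rollback()
    raise
finally:
    Session.remove()              # dispose of the thread-local session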
Example #2
    def __init__(self, dag_id=None, subdir=None, test_mode=False, refresh_dags_every=10, *args, **kwargs):
        self.dag_id = dag_id
        self.subdir = subdir
        self.test_mode = test_mode
        self.refresh_dags_every = refresh_dags_every
        super(SchedulerJob, self).__init__(*args, **kwargs)

        self.heartrate = conf.getint("scheduler", "SCHEDULER_HEARTBEAT_SEC")
Example #3
    def test_deprecated_options(self):
        # Guarantee we have a deprecated setting, so we test the deprecation
        # lookup even if we remove this explicit fallback
        conf.deprecated_options['celery'] = {
            'worker_concurrency': 'celeryd_concurrency',
        }

        # Remove it so we are sure we use the right setting
        conf.remove_option('celery', 'worker_concurrency')

        with self.assertWarns(DeprecationWarning):
            os.environ['AIRFLOW__CELERY__CELERYD_CONCURRENCY'] = '99'
            self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99)
            os.environ.pop('AIRFLOW__CELERY__CELERYD_CONCURRENCY')

        with self.assertWarns(DeprecationWarning):
            conf.set('celery', 'celeryd_concurrency', '99')
            self.assertEqual(conf.getint('celery', 'worker_concurrency'), 99)
            conf.remove_option('celery', 'celeryd_concurrency')
Example #4
    def test_deprecated_options_cmd(self):
        # Guarantee we have a deprecated setting, so we test the deprecation
        # lookup even if we remove this explicit fallback
        conf.deprecated_options['celery'] = {'result_backend': 'celery_result_backend'}
        conf.as_command_stdout.add(('celery', 'celery_result_backend'))

        conf.remove_option('celery', 'result_backend')
        conf.set('celery', 'celery_result_backend_cmd', '/bin/echo 99')

        with self.assertWarns(DeprecationWarning):
            self.assertEqual(conf.getint('celery', 'result_backend'), 99)
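The tests above rely on conf falling back from the new option name to the deprecated one while emitting a DeprecationWarning. The helper below is a rough, hypothetical sketch of that lookup, not Airflow's actual implementation.

import warnings

def _get_with_deprecation(conf, section, key):
    # Hypothetical: try the current name first, then fall back to the
    # deprecated name registered in conf.deprecated_options, warning if used.
    if conf.has_option(section, key):
        return conf.get(section, key)
    deprecated_key = conf.deprecated_options.get(section, {}).get(key)
    if deprecated_key and conf.has_option(section, deprecated_key):
        warnings.warn(
            "The {}.{} option is deprecated, use {}.{} instead".format(
                section, deprecated_key, section, key),
            DeprecationWarning)
        return conf.get(section, deprecated_key)
    raise KeyError("Option {}.{} not found".format(section, key))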
Example #5
    def __init__(
            self,
            executor=executors.DEFAULT_EXECUTOR,
            heartrate=conf.getint('scheduler', 'JOB_HEARTBEAT_SEC'),
            *args, **kwargs):
        self.hostname = socket.gethostname()
        self.executor = executor
        self.executor_class = executor.__class__.__name__
        self.start_date = datetime.now()
        self.latest_heartbeat = datetime.now()
        self.heartrate = heartrate
        self.unixname = getpass.getuser()
        super(BaseJob, self).__init__(*args, **kwargs)
Example #6
    def __init__(
            self,
            dag_id=None,
            subdir=None,
            test_mode=False,
            refresh_dags_every=10,
            *args, **kwargs):
        self.dag_id = dag_id
        self.subdir = subdir
        self.test_mode = test_mode
        self.refresh_dags_every = refresh_dags_every
        super(MasterJob, self).__init__(*args, **kwargs)

        self.heartrate = conf.getint('master', 'MASTER_HEARTBEAT_SEC')
Example #7
def send_MIME_email(e_from, e_to, mime_msg):
    SMTP_HOST = conf.get("smtp", "SMTP_HOST")
    SMTP_PORT = conf.getint("smtp", "SMTP_PORT")
    SMTP_USER = conf.get("smtp", "SMTP_USER")
    SMTP_PASSWORD = conf.get("smtp", "SMTP_PASSWORD")
    SMTP_STARTTLS = conf.getboolean("smtp", "SMTP_STARTTLS")

    s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
    if SMTP_STARTTLS:
        s.starttls()
    if SMTP_USER and SMTP_PASSWORD:
        s.login(SMTP_USER, SMTP_PASSWORD)
    logging.info("Sent an alert email to " + str(e_to))
    s.sendmail(e_from, e_to, mime_msg.as_string())
    s.quit()
Example #8
def send_MIME_email(e_from, e_to, mime_msg, dryrun=False):
    SMTP_HOST = conf.get('smtp', 'SMTP_HOST')
    SMTP_PORT = conf.getint('smtp', 'SMTP_PORT')
    SMTP_USER = conf.get('smtp', 'SMTP_USER')
    SMTP_PASSWORD = conf.get('smtp', 'SMTP_PASSWORD')
    SMTP_STARTTLS = conf.getboolean('smtp', 'SMTP_STARTTLS')

    if not dryrun:
        s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
        if SMTP_STARTTLS:
            s.starttls()
        if SMTP_USER and SMTP_PASSWORD:
            s.login(SMTP_USER, SMTP_PASSWORD)
        logging.info("Sent an alert email to " + str(e_to))
        s.sendmail(e_from, e_to, mime_msg.as_string())
        s.quit()
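A short usage sketch for the helper above: the message is built with the standard library email package, and dryrun=True exercises the SMTP config lookups without actually connecting or sending. The addresses are placeholders.

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart('mixed')
msg['Subject'] = 'Airflow alert'
msg['From'] = 'airflow@example.com'
msg['To'] = 'oncall@example.com'
msg.attach(MIMEText('<b>Task failed</b>', 'html'))

# dryrun=True reads the smtp settings but opens no SMTP connection.
send_MIME_email('airflow@example.com', ['oncall@example.com'], msg, dryrun=True)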
Example #9
    def test_deprecated_options_cmd(self):
        # Guarantee we have a deprecated setting, so we test the deprecation
        # lookup even if we remove this explicit fallback
        conf.deprecated_options['celery'] = {'result_backend': 'celery_result_backend'}
        conf.as_command_stdout.add(('celery', 'celery_result_backend'))

        conf.remove_option('celery', 'result_backend')
        conf.set('celery', 'celery_result_backend_cmd', '/bin/echo 99')

        with self.assertWarns(DeprecationWarning):
            tmp = None
            if 'AIRFLOW__CELERY__RESULT_BACKEND' in os.environ:
                tmp = os.environ.pop('AIRFLOW__CELERY__RESULT_BACKEND')
            self.assertEqual(conf.getint('celery', 'result_backend'), 99)
            if tmp:
                os.environ['AIRFLOW__CELERY__RESULT_BACKEND'] = tmp
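The environment variables toggled in these tests follow Airflow's AIRFLOW__{SECTION}__{KEY} naming convention, which takes precedence over values in airflow.cfg. The helper below is illustrative only, not Airflow's code.

import os

def env_var_name(section, key):
    # Illustrative: Airflow reads overrides from AIRFLOW__{SECTION}__{KEY}.
    return 'AIRFLOW__{}__{}'.format(section.upper(), key.upper())

os.environ[env_var_name('celery', 'celeryd_concurrency')] = '99'
# conf.getint('celery', 'worker_concurrency') would now resolve to 99 through
# the deprecated-name fallback exercised in the tests above.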
Example #10
    def __init__(
            self,
            dag_id=None,
            subdir=None,
            test_mode=False,
            refresh_dags_every=10,
            num_runs=None,
            *args, **kwargs):

        self.dag_id = dag_id
        self.subdir = subdir
        if test_mode:
            self.num_runs = 1
        else:
            self.num_runs = num_runs
        self.refresh_dags_every = refresh_dags_every
        super(SchedulerJob, self).__init__(*args, **kwargs)

        self.heartrate = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC')
Example #11
from builtins import range
from builtins import object
import logging

from airflow.utils import State
from airflow.configuration import conf

PARALLELISM = conf.getint("core", "PARALLELISM")


class BaseExecutor(object):
    def __init__(self, parallelism=PARALLELISM):
        """
        Class to derive in order to interface with executor-type systems
        like Celery, Mesos, Yarn and the likes.

        :param parallelism: how many jobs should run at one time. Set to
            ``0`` for infinity
        :type parallelism: int
        """
        self.parallelism = parallelism
        self.queued_tasks = {}
        self.running = {}
        self.event_buffer = {}

    def start(self):  # pragma: no cover
        """
        Executors may need to get things started. For example LocalExecutor
        starts N workers.
        """
        pass
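A minimal sketch of a hypothetical subclass. It assumes that, as in Airflow's BaseExecutor, the base class also calls execute_async() for queued work and sync() on each heartbeat; those hooks are not shown in the excerpt above.

import subprocess


class ShellExecutor(BaseExecutor):
    """Hypothetical executor that runs each command in a local subprocess."""

    def start(self):
        # Nothing to bootstrap for a subprocess-based executor.
        pass

    def execute_async(self, key, command, queue=None):
        # Assumed hook: launch the command without blocking.
        self.running[key] = subprocess.Popen(command, shell=True)

    def sync(self):
        # Assumed hook: collect results for finished processes.
        for key, proc in list(self.running.items()):
            if proc.poll() is not None:
                state = State.SUCCESS if proc.returncode == 0 else State.FAILED
                self.event_buffer[key] = state
                del self.running[key]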
Example #12
from airflow import models
from airflow import settings
from airflow import utils
from airflow.configuration import conf
from airflow.utils import State
import socket

from sqlalchemy import Column, Integer, String

Base = models.Base
ID_LEN = models.ID_LEN

# Setting up a statsd client if needed
statsd = None
if conf.getboolean('master', 'statsd_on'):
    from statsd import StatsClient
    statsd = StatsClient(
        host=conf.get('master', 'statsd_host'),
        port=conf.getint('master', 'statsd_port'),
        prefix='airflow')


class BaseJob(Base):
    """
    Abstract class to be derived for jobs. Jobs are processing items with state
    and duration that aren't task instances. For instance a BackfillJob is
    a collection of task instance runs, but should have its own state, start
    and end time.
    """

    __tablename__ = "job"

    id = Column(Integer, primary_key=True)
    dag_id = Column(String(ID_LEN),)
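The statsd client configured above is meant to be shared by job classes; the call below is a brief, hypothetical use, not a call site shown in the excerpt.

# Illustrative only: jobs can emit metrics when the statsd client is configured.
if statsd:
    statsd.incr('scheduler_heartbeat', 1, 1)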
Example #13
def restart_workers(gunicorn_master_proc, num_workers_expected,
                    master_timeout):
    """
    Runs forever, monitoring the child processes of @gunicorn_master_proc and
    restarting workers occasionally.
    Each iteration of the loop traverses one edge of this state transition
    diagram, where each state (node) represents
    [ num_ready_workers_running / num_workers_running ]. We expect most time to
    be spent in [n / n]. `bs` is the setting webserver.worker_refresh_batch_size.
    The horizontal transition at ? happens after the new worker parses all the
    dags (so it could take a while!)
       V ────────────────────────────────────────────────────────────────────────┐
    [n / n] ──TTIN──> [ [n, n+bs) / n + bs ]  ────?───> [n + bs / n + bs] ──TTOU─┘
       ^                          ^───────────────┘
       │
       │      ┌────────────────v
       └──────┴────── [ [0, n) / n ] <─── start
    We change the number of workers by sending TTIN and TTOU to the gunicorn
    master process, which increases and decreases the number of child workers
    respectively. Gunicorn guarantees that on TTOU workers are terminated
    gracefully and that the oldest worker is terminated.
    """
    def wait_until_true(fn, timeout=0):
        """
        Sleeps until fn is true, raising AirflowWebServerTimeout if timeout
        (in seconds; 0 means no limit) elapses first
        """
        start_time = time.time()
        while not fn():
            if 0 < timeout <= time.time() - start_time:
                raise AirflowWebServerTimeout(
                    "No response from gunicorn master within {0} seconds".
                    format(timeout))
            time.sleep(0.1)

    def start_refresh(gunicorn_master_proc):
        batch_size = conf.getint('webserver', 'worker_refresh_batch_size')
        log.debug('%s doing a refresh of %s workers', state, batch_size)
        sys.stdout.flush()
        sys.stderr.flush()

        excess = 0
        for _ in range(batch_size):
            gunicorn_master_proc.send_signal(signal.SIGTTIN)
            excess += 1
            wait_until_true(
                lambda: num_workers_expected + excess ==
                get_num_workers_running(gunicorn_master_proc), master_timeout)

    try:  # pylint: disable=too-many-nested-blocks
        wait_until_true(
            lambda: num_workers_expected == get_num_workers_running(
                gunicorn_master_proc), master_timeout)
        while True:
            num_workers_running = get_num_workers_running(gunicorn_master_proc)
            num_ready_workers_running = \
                get_num_ready_workers_running(gunicorn_master_proc)

            state = '[{0} / {1}]'.format(num_ready_workers_running,
                                         num_workers_running)

            # Whenever some workers are not ready, wait until all workers are ready
            if num_ready_workers_running < num_workers_running:
                log.debug('%s some workers are starting up, waiting...', state)
                sys.stdout.flush()
                time.sleep(1)

            # Kill a worker gracefully by asking gunicorn to reduce number of workers
            elif num_workers_running > num_workers_expected:
                excess = num_workers_running - num_workers_expected
                log.debug('%s killing %s workers', state, excess)

                for _ in range(excess):
                    gunicorn_master_proc.send_signal(signal.SIGTTOU)
                    excess -= 1
                    wait_until_true(
                        lambda: num_workers_expected + excess ==
                        get_num_workers_running(gunicorn_master_proc),
                        master_timeout)

            # Start a new worker by asking gunicorn to increase number of workers
            elif num_workers_running == num_workers_expected:
                refresh_interval = conf.getint('webserver',
                                               'worker_refresh_interval')
                log.debug('%s sleeping for %ss before starting a refresh...',
                          state, refresh_interval)
                time.sleep(refresh_interval)
                start_refresh(gunicorn_master_proc)

            else:
                # num_ready_workers_running == num_workers_running < num_workers_expected
                log.error(("%s some workers seem to have died and gunicorn"
                           "did not restart them as expected"), state)
                time.sleep(10)
                if len(psutil.Process(gunicorn_master_proc.pid).children()
                       ) < num_workers_expected:
                    start_refresh(gunicorn_master_proc)
    except (AirflowWebServerTimeout, OSError) as err:
        log.error(err)
        log.error("Shutting down webserver")
        try:
            gunicorn_master_proc.terminate()
            gunicorn_master_proc.wait()
        finally:
            sys.exit(1)
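restart_workers depends on a get_num_workers_running helper that is not shown in this excerpt (nor is get_num_ready_workers_running). The sketch below is a plausible implementation, consistent with how the loop above already inspects the gunicorn master's children via psutil.

import psutil


def get_num_workers_running(gunicorn_master_proc):
    # Count the gunicorn master's direct children, i.e. its worker processes.
    return len(psutil.Process(gunicorn_master_proc.pid).children())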
Example #14
class CLIFactory:
    """
    Factory class which generates command line argument parser and holds information
    about all available Airflow commands
    """
    args = {
        # Shared
        'dag_id':
        Arg(("dag_id", ), "The id of the dag"),
        'task_id':
        Arg(("task_id", ), "The id of the task"),
        'execution_date':
        Arg(("execution_date", ),
            help="The execution date of the DAG",
            type=parsedate),
        'task_regex':
        Arg(("-t", "--task_regex"),
            "The regex to filter specific task_ids to backfill (optional)"),
        'subdir':
        Arg(("-sd", "--subdir"),
            "File location or directory from which to look for the dag. "
            "Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the "
            "value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' ",
            default=DAGS_FOLDER),
        'start_date':
        Arg(("-s", "--start_date"),
            "Override start_date YYYY-MM-DD",
            type=parsedate),
        'end_date':
        Arg(("-e", "--end_date"),
            "Override end_date YYYY-MM-DD",
            type=parsedate),
        'dry_run':
        Arg(("-dr", "--dry_run"), "Perform a dry run", "store_true"),
        'pid':
        Arg(("--pid", ), "PID file location", nargs='?'),
        'daemon':
        Arg(("-D", "--daemon"), "Daemonize instead of running "
            "in the foreground", "store_true"),
        'stderr':
        Arg(("--stderr", ), "Redirect stderr to this file"),
        'stdout':
        Arg(("--stdout", ), "Redirect stdout to this file"),
        'log_file':
        Arg(("-l", "--log-file"), "Location of the log file"),
        'yes':
        Arg(("-y", "--yes"),
            "Do not prompt to confirm reset. Use with care!",
            "store_true",
            default=False),
        'output':
        Arg(("--output", ),
            ("Output table format. The specified value is passed to "
             "the tabulate module (https://pypi.org/project/tabulate/). "
             "Valid values are: ({})".format("|".join(tabulate_formats))),
            choices=tabulate_formats,
            default="fancy_grid"),

        # list_dag_runs
        'no_backfill':
        Arg(("--no_backfill", ),
            "filter all the backfill dagruns given the dag id", "store_true"),
        'state':
        Arg(("--state", ),
            "Only list the dag runs corresponding to the state"),

        # list_jobs
        'limit':
        Arg(("--limit", ), "Return a limited number of records"),

        # backfill
        'mark_success':
        Arg(("-m", "--mark_success"),
            "Mark jobs as succeeded without running them", "store_true"),
        'verbose':
        Arg(("-v", "--verbose"), "Make logging output more verbose",
            "store_true"),
        'local':
        Arg(("-l", "--local"), "Run the task using the LocalExecutor",
            "store_true"),
        'donot_pickle':
        Arg(("-x", "--donot_pickle"),
            ("Do not attempt to pickle the DAG object to send over "
             "to the workers, just tell the workers to run their version "
             "of the code"), "store_true"),
        'bf_ignore_dependencies':
        Arg(("-i", "--ignore_dependencies"),
            ("Skip upstream tasks, run only the tasks "
             "matching the regexp. Only works in conjunction "
             "with task_regex"), "store_true"),
        'bf_ignore_first_depends_on_past':
        Arg(("-I", "--ignore_first_depends_on_past"),
            ("Ignores depends_on_past dependencies for the first "
             "set of tasks only (subsequent executions in the backfill "
             "DO respect depends_on_past)"), "store_true"),
        'pool':
        Arg(("--pool", ), "Resource pool to use"),
        'delay_on_limit':
        Arg(("--delay_on_limit", ),
            help=("Amount of time in seconds to wait when the limit "
                  "on maximum active dag runs (max_active_runs) has "
                  "been reached before trying to execute a dag run "
                  "again"),
            type=float,
            default=1.0),
        'reset_dag_run':
        Arg(("--reset_dagruns", ),
            ("if set, the backfill will delete existing "
             "backfill-related DAG runs and start "
             "anew with fresh, running DAG runs"), "store_true"),
        'rerun_failed_tasks':
        Arg(("--rerun_failed_tasks", ),
            ("if set, the backfill will auto-rerun "
             "all the failed tasks for the backfill date range "
             "instead of throwing exceptions"), "store_true"),
        'run_backwards':
        Arg((
            "-B",
            "--run_backwards",
        ), ("if set, the backfill will run tasks from the most "
            "recent day first.  if there are tasks that depend_on_past "
            "this option will throw an exception"), "store_true"),

        # list_tasks
        'tree':
        Arg(("-t", "--tree"), "Tree view", "store_true"),
        # list_dags
        'report':
        Arg(("-r", "--report"), "Show DagBag loading report", "store_true"),
        # clear
        'upstream':
        Arg(("-u", "--upstream"), "Include upstream tasks", "store_true"),
        'only_failed':
        Arg(("-f", "--only_failed"), "Only failed jobs", "store_true"),
        'only_running':
        Arg(("-r", "--only_running"), "Only running jobs", "store_true"),
        'downstream':
        Arg(("-d", "--downstream"), "Include downstream tasks", "store_true"),
        'exclude_subdags':
        Arg(("-x", "--exclude_subdags"), "Exclude subdags", "store_true"),
        'exclude_parentdag':
        Arg(("-xp", "--exclude_parentdag"),
            "Exclude ParentDAGS if the task cleared is a part of a SubDAG",
            "store_true"),
        'dag_regex':
        Arg(("-dx", "--dag_regex"),
            "Search dag_id as regex instead of exact string", "store_true"),
        # show_dag
        'save':
        Arg(("-s", "--save"), "Saves the result to the indicated file.\n"
            "\n"
            "The file format is determined by the file extension. For more information about supported "
            "format, see: https://www.graphviz.org/doc/info/output.html\n"
            "\n"
            "If you want to create a PNG file then you should execute the following command:\n"
            "airflow dags show <DAG_ID> --save output.png\n"
            "\n"
            "If you want to create a DOT file then you should execute the following command:\n"
            "airflow dags show <DAG_ID> --save output.dot\n"),
        'imgcat':
        Arg(("--imgcat", ), "Displays graph using the imgcat tool. \n"
            "\n"
            "For more information, see: https://www.iterm2.com/documentation-images.html",
            action='store_true'),
        # trigger_dag
        'run_id':
        Arg(("-r", "--run_id"), "Helps to identify this run"),
        'conf':
        Arg(('-c', '--conf'),
            "JSON string that gets pickled into the DagRun's conf attribute"),
        'exec_date':
        Arg(("-e", "--exec_date"),
            help="The execution date of the DAG",
            type=parsedate),
        # pool
        'pool_name':
        Arg(("pool", ), metavar='NAME', help="Pool name"),
        'pool_slots':
        Arg(("slots", ), type=int, help="Pool slots"),
        'pool_description':
        Arg(("description", ), help="Pool description"),
        'pool_import':
        Arg(("file", ), metavar="FILEPATH",
            help="Import pools from JSON file"),
        'pool_export':
        Arg(("file", ),
            metavar="FILEPATH",
            help="Export all pools to JSON file"),
        # variables
        'var':
        Arg(("key", ), help="Variable key"),
        'var_value':
        Arg(("value", ), metavar='VALUE', help="Variable value"),
        'default':
        Arg(("-d", "--default"),
            metavar="VAL",
            default=None,
            help="Default value returned if variable does not exist"),
        'json':
        Arg(("-j", "--json"),
            help="Deserialize JSON variable",
            action="store_true"),
        'var_import':
        Arg(("file", ), help="Import variables from JSON file"),
        'var_export':
        Arg(("file", ), help="Export all variables to JSON file"),
        # kerberos
        'principal':
        Arg(("principal", ), "kerberos principal", nargs='?'),
        'keytab':
        Arg(("-kt", "--keytab"),
            "keytab",
            nargs='?',
            default=conf.get('kerberos', 'keytab')),
        # run
        # TODO(aoen): "force" is a poor choice of name here since it implies it overrides
        # all dependencies (not just past success), e.g. the ignore_depends_on_past
        # dependency. This flag should be deprecated and renamed to 'ignore_ti_state' and
        # the "ignore_all_dependencies" command should be called the"force" command
        # instead.
        'interactive':
        Arg(('-int', '--interactive'),
            help='Do not capture standard output and error streams '
            '(useful for interactive debugging)',
            action='store_true'),
        'force':
        Arg((
            "-f", "--force"
        ), "Ignore previous task instance state, rerun regardless if task already "
            "succeeded/failed", "store_true"),
        'raw':
        Arg(("-r", "--raw"), argparse.SUPPRESS, "store_true"),
        'ignore_all_dependencies':
        Arg((
            "-A", "--ignore_all_dependencies"
        ), "Ignores all non-critical dependencies, including ignore_ti_state and "
            "ignore_task_deps", "store_true"),
        # TODO(aoen): ignore_dependencies is a poor choice of name here because it is too
        # vague (e.g. a task being in the appropriate state to be run is also a dependency
        # but is not ignored by this flag), the name 'ignore_task_dependencies' is
        # slightly better (as it ignores all dependencies that are specific to the task),
        # so deprecate the old command name and use this instead.
        'ignore_dependencies':
        Arg((
            "-i", "--ignore_dependencies"
        ), "Ignore task-specific dependencies, e.g. upstream, depends_on_past, and "
            "retry delay dependencies", "store_true"),
        'ignore_depends_on_past':
        Arg(("-I", "--ignore_depends_on_past"),
            "Ignore depends_on_past dependencies (but respect "
            "upstream dependencies)", "store_true"),
        'ship_dag':
        Arg(("--ship_dag", ),
            "Pickles (serializes) the DAG and ships it to the worker",
            "store_true"),
        'pickle':
        Arg(("-p", "--pickle"),
            "Serialized pickle object of the entire dag (used internally)"),
        'job_id':
        Arg(("-j", "--job_id"), argparse.SUPPRESS),
        'cfg_path':
        Arg(("--cfg_path", ),
            "Path to config file to use instead of airflow.cfg"),
        # webserver
        'port':
        Arg(("-p", "--port"),
            default=conf.get('webserver', 'WEB_SERVER_PORT'),
            type=int,
            help="The port on which to run the server"),
        'ssl_cert':
        Arg(("--ssl_cert", ),
            default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'),
            help="Path to the SSL certificate for the webserver"),
        'ssl_key':
        Arg(("--ssl_key", ),
            default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'),
            help="Path to the key to use with the SSL certificate"),
        'workers':
        Arg(("-w", "--workers"),
            default=conf.get('webserver', 'WORKERS'),
            type=int,
            help="Number of workers to run the webserver on"),
        'workerclass':
        Arg(("-k", "--workerclass"),
            default=conf.get('webserver', 'WORKER_CLASS'),
            choices=['sync', 'eventlet', 'gevent', 'tornado'],
            help="The worker class to use for Gunicorn"),
        'worker_timeout':
        Arg(("-t", "--worker_timeout"),
            default=conf.get('webserver', 'WEB_SERVER_WORKER_TIMEOUT'),
            type=int,
            help="The timeout for waiting on webserver workers"),
        'hostname':
        Arg(("-hn", "--hostname"),
            default=conf.get('webserver', 'WEB_SERVER_HOST'),
            help="Set the hostname on which to run the web server"),
        'debug':
        Arg(("-d", "--debug"),
            "Use the server that ships with Flask in debug mode",
            "store_true"),
        'access_logfile':
        Arg(("-A", "--access_logfile"),
            default=conf.get('webserver', 'ACCESS_LOGFILE'),
            help=
            "The logfile to store the webserver access log. Use '-' to print to "
            "stderr"),
        'error_logfile':
        Arg(("-E", "--error_logfile"),
            default=conf.get('webserver', 'ERROR_LOGFILE'),
            help=
            "The logfile to store the webserver error log. Use '-' to print to "
            "stderr"),
        # scheduler
        'dag_id_opt':
        Arg(("-d", "--dag_id"), help="The id of the dag to run"),
        'num_runs':
        Arg(("-n", "--num_runs"),
            default=conf.getint('scheduler', 'num_runs'),
            type=int,
            help="Set the number of runs to execute before exiting"),
        # worker
        'do_pickle':
        Arg(("-p", "--do_pickle"),
            default=False,
            help=(
                "Attempt to pickle the DAG object to send over "
                "to the workers, instead of letting workers run their version "
                "of the code"),
            action="store_true"),
        'queues':
        Arg(("-q", "--queues"),
            help="Comma delimited list of queues to serve",
            default=conf.get('celery', 'DEFAULT_QUEUE')),
        'concurrency':
        Arg(("-c", "--concurrency"),
            type=int,
            help="The number of worker processes",
            default=conf.get('celery', 'worker_concurrency')),
        'celery_hostname':
        Arg(("-cn", "--celery_hostname"),
            help=("Set the hostname of celery worker "
                  "if you have multiple workers on a single machine")),
        # flower
        'broker_api':
        Arg(("-a", "--broker_api"), help="Broker api"),
        'flower_hostname':
        Arg(("-hn", "--hostname"),
            default=conf.get('celery', 'FLOWER_HOST'),
            help="Set the hostname on which to run the server"),
        'flower_port':
        Arg(("-p", "--port"),
            default=conf.get('celery', 'FLOWER_PORT'),
            type=int,
            help="The port on which to run the server"),
        'flower_conf':
        Arg(("-fc", "--flower_conf"), help="Configuration file for flower"),
        'flower_url_prefix':
        Arg(("-u", "--url_prefix"),
            default=conf.get('celery', 'FLOWER_URL_PREFIX'),
            help="URL prefix for Flower"),
        'flower_basic_auth':
        Arg(("-ba", "--basic_auth"),
            default=conf.get('celery', 'FLOWER_BASIC_AUTH'),
            help=(
                "Securing Flower with Basic Authentication. "
                "Accepts user:password pairs separated by a comma. "
                "Example: flower_basic_auth = user1:password1,user2:password2"
            )),
        'task_params':
        Arg(("-tp", "--task_params"),
            help="Sends a JSON params dict to the task"),
        'post_mortem':
        Arg(
            ("-pm", "--post_mortem"),
            action="store_true",
            help="Open debugger on uncaught exception",
        ),
        # connections
        'conn_id':
        Arg(('conn_id', ),
            help='Connection id, required to add/delete a connection',
            type=str),
        'conn_uri':
        Arg(('--conn_uri', ),
            help=
            'Connection URI, required to add a connection without conn_type',
            type=str),
        'conn_type':
        Arg(('--conn_type', ),
            help=
            'Connection type, required to add a connection without conn_uri',
            type=str),
        'conn_host':
        Arg(('--conn_host', ),
            help='Connection host, optional when adding a connection',
            type=str),
        'conn_login':
        Arg(('--conn_login', ),
            help='Connection login, optional when adding a connection',
            type=str),
        'conn_password':
        Arg(('--conn_password', ),
            help='Connection password, optional when adding a connection',
            type=str),
        'conn_schema':
        Arg(('--conn_schema', ),
            help='Connection schema, optional when adding a connection',
            type=str),
        'conn_port':
        Arg(('--conn_port', ),
            help='Connection port, optional when adding a connection',
            type=str),
        'conn_extra':
        Arg(('--conn_extra', ),
            help='Connection `Extra` field, optional when adding a connection',
            type=str),
        # users
        'username':
        Arg(('--username', ),
            help='Username of the user',
            required=True,
            type=str),
        'username_optional':
        Arg(('--username', ), help='Username of the user', type=str),
        'firstname':
        Arg(('--firstname', ),
            help='First name of the user',
            required=True,
            type=str),
        'lastname':
        Arg(('--lastname', ),
            help='Last name of the user',
            required=True,
            type=str),
        'role':
        Arg(
            ('--role', ),
            help='Role of the user. Existing roles include Admin, '
            'User, Op, Viewer, and Public',
            required=True,
            type=str,
        ),
        'email':
        Arg(('--email', ), help='Email of the user', required=True, type=str),
        'email_optional':
        Arg(('--email', ), help='Email of the user', type=str),
        'password':
        Arg(('--password', ),
            help='Password of the user, required to create a user '
            'without --use_random_password',
            type=str),
        'use_random_password':
        Arg(('--use_random_password', ),
            help='Do not prompt for password. Use random string instead.'
            ' Required to create a user without --password ',
            default=False,
            action='store_true'),
        'user_import':
        Arg(
            ("import", ),
            metavar="FILEPATH",
            help="Import users from JSON file. Example format::\n" +
            textwrap.indent(
                textwrap.dedent('''
                    [
                        {
                            "email": "*****@*****.**",
                            "firstname": "Jon",
                            "lastname": "Doe",
                            "roles": ["Public"],
                            "username": "******"
                        }
                    ]'''), " " * 4),
        ),
        'user_export':
        Arg(("export", ),
            metavar="FILEPATH",
            help="Export all users to JSON file"),
        # roles
        'create_role':
        Arg(('-c', '--create'), help='Create a new role', action='store_true'),
        'list_roles':
        Arg(('-l', '--list'), help='List roles', action='store_true'),
        'roles':
        Arg(('role', ), help='The name of a role', nargs='*'),
        'autoscale':
        Arg(('-a', '--autoscale'),
            help="Minimum and Maximum number of worker to autoscale"),
        'skip_serve_logs':
        Arg(("-s", "--skip_serve_logs"),
            default=False,
            help="Don't start the serve logs process along with the workers",
            action="store_true"),
    }
    DAGS_SUBCOMMANDS = (
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.dag_command.dag_list_dags'),
            'name':
            'list',
            'help':
            "List all the DAGs",
            'args': ('subdir', 'report'),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.dag_command.dag_list_dag_runs'),
            'name':
            'list_runs',
            'help':
            "List dag runs given a DAG id. If state option is given, it will only "
            "search for all the dagruns with the given state. "
            "If no_backfill option is given, it will filter out "
            "all backfill dagruns for given dag id",
            'args': (
                'dag_id',
                'no_backfill',
                'state',
                'output',
            ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.dag_command.dag_list_jobs'),
            'name':
            'list_jobs',
            'help':
            "List the jobs",
            'args': (
                'dag_id_opt',
                'state',
                'limit',
                'output',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_state'),
            'name': 'state',
            'help': "Get the status of a dag run",
            'args': ('dag_id', 'execution_date', 'subdir'),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.dag_command.dag_next_execution'),
            'name':
            'next_execution',
            'help':
            "Get the next execution datetime of a DAG",
            'args': ('dag_id', 'subdir'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_pause'),
            'name': 'pause',
            'help': 'Pause a DAG',
            'args': ('dag_id', 'subdir'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_unpause'),
            'name':
            'unpause',
            'help':
            'Resume a paused DAG',
            'args': ('dag_id', 'subdir'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_trigger'),
            'name':
            'trigger',
            'help':
            'Trigger a DAG run',
            'args': ('dag_id', 'subdir', 'run_id', 'conf', 'exec_date'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_delete'),
            'name':
            'delete',
            'help':
            "Delete all DB records related to the specified DAG",
            'args': ('dag_id', 'yes'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_show'),
            'name': 'show',
            'help': "Displays DAG's tasks with their dependencies",
            'args': (
                'dag_id',
                'subdir',
                'save',
                'imgcat',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.dag_command.dag_backfill'),
            'name':
            'backfill',
            'help':
            "Run subsections of a DAG for a specified date range. "
            "If reset_dag_run option is used,"
            " backfill will first prompt users whether airflow "
            "should clear all the previous dag_run and task_instances "
            "within the backfill date range. "
            "If rerun_failed_tasks is used, backfill "
            "will auto re-run the previous failed task instances"
            " within the backfill date range",
            'args':
            ('dag_id', 'task_regex', 'start_date', 'end_date', 'mark_success',
             'local', 'donot_pickle', 'yes', 'bf_ignore_dependencies',
             'bf_ignore_first_depends_on_past', 'subdir', 'pool',
             'delay_on_limit', 'dry_run', 'verbose', 'conf', 'reset_dag_run',
             'rerun_failed_tasks', 'run_backwards'),
        },
    )
    TASKS_COMMANDS = (
        {
            'func':
            lazy_load_command('airflow.cli.commands.task_command.task_list'),
            'name':
            'list',
            'help':
            "List the tasks within a DAG",
            'args': ('dag_id', 'tree', 'subdir'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.task_command.task_clear'),
            'name':
            'clear',
            'help':
            "Clear a set of task instance, as if they never ran",
            'args':
            ('dag_id', 'task_regex', 'start_date', 'end_date', 'subdir',
             'upstream', 'downstream', 'yes', 'only_failed', 'only_running',
             'exclude_subdags', 'exclude_parentdag', 'dag_regex'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.task_command.task_state'),
            'name':
            'state',
            'help':
            "Get the status of a task instance",
            'args': ('dag_id', 'task_id', 'execution_date', 'subdir'),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.task_command.task_failed_deps'),
            'name':
            'failed_deps',
            'help':
            ("Returns the unmet dependencies for a task instance from the perspective "
             "of the scheduler. In other words, why a task instance doesn't get "
             "scheduled and then queued by the scheduler, and then run by an "
             "executor)"),
            'args': ('dag_id', 'task_id', 'execution_date', 'subdir'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.task_command.task_render'),
            'name':
            'render',
            'help':
            "Render a task instance's template(s)",
            'args': ('dag_id', 'task_id', 'execution_date', 'subdir'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.task_command.task_run'),
            'name':
            'run',
            'help':
            "Run a single task instance",
            'args': (
                'dag_id',
                'task_id',
                'execution_date',
                'subdir',
                'mark_success',
                'force',
                'pool',
                'cfg_path',
                'local',
                'raw',
                'ignore_all_dependencies',
                'ignore_dependencies',
                'ignore_depends_on_past',
                'ship_dag',
                'pickle',
                'job_id',
                'interactive',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.task_command.task_test'),
            'name':
            'test',
            'help':
            ("Test a task instance. This will run a task without checking for "
             "dependencies or recording its state in the database"),
            'args': ('dag_id', 'task_id', 'execution_date', 'subdir',
                     'dry_run', 'task_params', 'post_mortem'),
        },
    )
    POOLS_COMMANDS = (
        {
            'func':
            lazy_load_command('airflow.cli.commands.pool_command.pool_list'),
            'name':
            'list',
            'help':
            'List pools',
            'args': ('output', ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.pool_command.pool_get'),
            'name': 'get',
            'help': 'Get pool size',
            'args': (
                'pool_name',
                'output',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.pool_command.pool_set'),
            'name': 'set',
            'help': 'Configure pool',
            'args': (
                'pool_name',
                'pool_slots',
                'pool_description',
                'output',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.pool_command.pool_delete'),
            'name':
            'delete',
            'help':
            'Delete pool',
            'args': (
                'pool_name',
                'output',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.pool_command.pool_import'),
            'name':
            'import',
            'help':
            'Import pools',
            'args': (
                'pool_import',
                'output',
            ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.pool_command.pool_export'),
            'name':
            'export',
            'help':
            'Export all pools',
            'args': (
                'pool_export',
                'output',
            ),
        },
    )
    VARIABLES_COMMANDS = (
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.variable_command.variables_list'),
            'name':
            'list',
            'help':
            'List variables',
            'args': (),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.variable_command.variables_get'),
            'name':
            'get',
            'help':
            'Get variable',
            'args': ('var', 'json', 'default'),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.variable_command.variables_set'),
            'name':
            'set',
            'help':
            'Set variable',
            'args': ('var', 'var_value', 'json'),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.variable_command.variables_delete'),
            'name':
            'delete',
            'help':
            'Delete variable',
            'args': ('var', ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.variable_command.variables_import'),
            'name':
            'import',
            'help':
            'Import variables',
            'args': ('var_import', ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.variable_command.variables_export'),
            'name':
            'export',
            'help':
            'Export all variables',
            'args': ('var_export', ),
        },
    )
    DB_COMMANDS = (
        {
            'func':
            lazy_load_command('airflow.cli.commands.db_command.initdb'),
            'name': 'init',
            'help': "Initialize the metadata database",
            'args': (),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.db_command.resetdb'),
            'name': 'reset',
            'help': "Burn down and rebuild the metadata database",
            'args': ('yes', ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.db_command.upgradedb'),
            'name': 'upgrade',
            'help': "Upgrade the metadata database to latest version",
            'args': tuple(),
        },
        {
            'func': lazy_load_command('airflow.cli.commands.db_command.shell'),
            'name': 'shell',
            'help': "Runs a shell to access the database",
            'args': tuple(),
        },
    )
    CONNECTIONS_COMMANDS = (
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.connection_command.connections_list'),
            'name':
            'list',
            'help':
            'List connections',
            'args': ('output', ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.connection_command.connections_add'),
            'name':
            'add',
            'help':
            'Add a connection',
            'args': ('conn_id', 'conn_uri', 'conn_extra') +
            tuple(alternative_conn_specs),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.connection_command.connections_delete'),
            'name':
            'delete',
            'help':
            'Delete a connection',
            'args': ('conn_id', ),
        },
    )
    USERS_COMMANDS = (
        {
            'func':
            lazy_load_command('airflow.cli.commands.user_command.users_list'),
            'name':
            'list',
            'help':
            'List users',
            'args': ('output', ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.user_command.users_create'),
            'name':
            'create',
            'help':
            'Create a user',
            'args': ('role', 'username', 'email', 'firstname', 'lastname',
                     'password', 'use_random_password')
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.user_command.users_delete'),
            'name':
            'delete',
            'help':
            'Delete a user',
            'args': ('username', ),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.user_command.add_role'),
            'name': 'add_role',
            'help': 'Add role to a user',
            'args': ('username_optional', 'email_optional', 'role'),
        },
        {
            'func':
            lazy_load_command('airflow.cli.commands.user_command.remove_role'),
            'name':
            'remove_role',
            'help':
            'Remove role from a user',
            'args': ('username_optional', 'email_optional', 'role'),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.user_command.users_import'),
            'name':
            'import',
            'help':
            'Import users',
            'args': ('user_import', ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.user_command.users_export'),
            'name':
            'export',
            'help':
            'Export all users',
            'args': ('user_export', ),
        },
    )
    ROLES_COMMANDS = (
        {
            'func':
            lazy_load_command('airflow.cli.commands.role_command.roles_list'),
            'name':
            'list',
            'help':
            'List roles',
            'args': ('output', ),
        },
        {
            'func':
            lazy_load_command(
                'airflow.cli.commands.role_command.roles_create'),
            'name':
            'create',
            'help':
            'Create role',
            'args': ('roles', ),
        },
    )
    subparsers = [
        {
            'help': 'List and manage DAGs',
            'name': 'dags',
            'subcommands': DAGS_SUBCOMMANDS,
        },
        {
            'help': 'List and manage tasks',
            'name': 'tasks',
            'subcommands': TASKS_COMMANDS,
        },
        {
            'help': "CRUD operations on pools",
            'name': 'pools',
            'subcommands': POOLS_COMMANDS,
        },
        {
            'help':
            "CRUD operations on variables",
            'name':
            'variables',
            'subcommands':
            VARIABLES_COMMANDS,
            "args": ('set', 'get', 'json', 'default', 'var_import',
                     'var_export', 'var_delete'),
        },
        {
            'help': "Database operations",
            'name': 'db',
            'subcommands': DB_COMMANDS,
        },
        {
            'name':
            'kerberos',
            'func':
            lazy_load_command(
                'airflow.cli.commands.kerberos_command.kerberos'),
            'help':
            "Start a kerberos ticket renewer",
            'args': ('principal', 'keytab', 'pid', 'daemon', 'stdout',
                     'stderr', 'log_file'),
        },
        {
            'name':
            'webserver',
            'func':
            lazy_load_command(
                'airflow.cli.commands.webserver_command.webserver'),
            'help':
            "Start a Airflow webserver instance",
            'args':
            ('port', 'workers', 'workerclass', 'worker_timeout', 'hostname',
             'pid', 'daemon', 'stdout', 'stderr', 'access_logfile',
             'error_logfile', 'log_file', 'ssl_cert', 'ssl_key', 'debug'),
        },
        {
            'name':
            'scheduler',
            'func':
            lazy_load_command(
                'airflow.cli.commands.scheduler_command.scheduler'),
            'help':
            "Start a scheduler instance",
            'args': ('dag_id_opt', 'subdir', 'num_runs', 'do_pickle', 'pid',
                     'daemon', 'stdout', 'stderr', 'log_file'),
        },
        {
            'name':
            'version',
            'func':
            lazy_load_command('airflow.cli.commands.version_command.version'),
            'help':
            "Show the version",
            'args':
            tuple(),
        },
        {
            'help': "List/Add/Delete connections",
            'name': 'connections',
            'subcommands': CONNECTIONS_COMMANDS,
        },
        {
            'help': "CRUD operations on users",
            'name': 'users',
            'subcommands': USERS_COMMANDS,
        },
        {
            'help': 'Create/List roles',
            'name': 'roles',
            'subcommands': ROLES_COMMANDS,
        },
        {
            'name':
            'sync_perm',
            'func':
            lazy_load_command(
                'airflow.cli.commands.sync_perm_command.sync_perm'),
            'help':
            "Update permissions for existing roles and DAGs",
            'args':
            tuple(),
        },
        {
            'name':
            'rotate_fernet_key',
            'func':
            lazy_load_command(
                'airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key'
            ),
            'help':
            'Rotate all encrypted connection credentials and variables; see '
            'https://airflow.readthedocs.io/en/stable/howto/secure-connections.html'
            '#rotating-encryption-keys',
            'args': (),
        },
        {
            'name':
            'config',
            'func':
            lazy_load_command(
                'airflow.cli.commands.config_command.show_config'),
            'help':
            'Show current application configuration',
            'args': (),
        },
    ]
    if conf.get("core",
                "EXECUTOR") == ExecutorLoader.CELERY_EXECUTOR or BUILD_DOCS:
        subparsers.append({
            "help":
            "Start celery components",
            "name":
            "celery",
            "subcommands": (
                {
                    'name':
                    'worker',
                    'func':
                    lazy_load_command(
                        'airflow.cli.commands.celery_command.worker'),
                    'help':
                    "Start a Celery worker node",
                    'args':
                    ('do_pickle', 'queues', 'concurrency', 'celery_hostname',
                     'pid', 'daemon', 'stdout', 'stderr', 'log_file',
                     'autoscale', 'skip_serve_logs'),
                },
                {
                    'name':
                    'flower',
                    'func':
                    lazy_load_command(
                        'airflow.cli.commands.celery_command.flower'),
                    'help':
                    "Start a Celery Flower",
                    'args':
                    ('flower_hostname', 'flower_port', 'flower_conf',
                     'flower_url_prefix', 'flower_basic_auth', 'broker_api',
                     'pid', 'daemon', 'stdout', 'stderr', 'log_file'),
                },
            )
        })
    subparsers_dict = {
        sp.get('name') or sp['func'].__name__: sp
        for sp in subparsers
    }  # type: ignore
    dag_subparsers = ('list_tasks', 'backfill', 'test', 'run', 'pause',
                      'unpause', 'list_dag_runs')

    @classmethod
    def get_parser(cls, dag_parser=False):
        """Creates and returns command line argument parser"""
        class DefaultHelpParser(argparse.ArgumentParser):
            """Override argparse.ArgumentParser.error and use print_help instead of print_usage"""
            def error(self, message):
                self.print_help()
                self.exit(
                    2, '\n{} command error: {}, see help above.\n'.format(
                        self.prog, message))

        parser = DefaultHelpParser()
        subparsers = parser.add_subparsers(help='sub-command help',
                                           dest='subcommand')
        subparsers.required = True

        subparser_list = cls.dag_subparsers if dag_parser else cls.subparsers_dict.keys(
        )
        for sub in sorted(subparser_list):
            sub = cls.subparsers_dict[sub]
            cls._add_subcommand(subparsers, sub)
        return parser

    @classmethod
    def sort_args(cls, args: Arg):
        """
        Sort subcommand optional args, keep positional args
        """
        def partition(pred, iterable):
            """
            Use a predicate to partition entries into false entries and true entries
            """
            iter_1, iter_2 = tee(iterable)
            return filterfalse(pred, iter_1), filter(pred, iter_2)

        def get_long_option(arg):
            """
            Get long option from Arg.flags
            """
            return cls.args[arg].flags[0] if len(
                cls.args[arg].flags) == 1 else cls.args[arg].flags[1]

        positional, optional = partition(
            lambda x: cls.args[x].flags[0].startswith("-"), args)
        yield from positional
        yield from sorted(optional, key=lambda x: get_long_option(x).lower())

    @classmethod
    def _add_subcommand(cls, subparsers, sub):
        dag_parser = False
        sub_proc = subparsers.add_parser(
            sub.get('name') or sub['func'].__name__,
            help=sub['help']  # type: ignore
        )
        sub_proc.formatter_class = RawTextHelpFormatter

        subcommands = sub.get('subcommands', [])
        if subcommands:
            sub_subparsers = sub_proc.add_subparsers(dest='subcommand')
            sub_subparsers.required = True
            for command in sorted(subcommands, key=lambda x: x['name']):
                cls._add_subcommand(sub_subparsers, command)
        else:
            for arg in cls.sort_args(sub['args']):
                if 'dag_id' in arg and dag_parser:
                    continue
                arg = cls.args[arg]
                kwargs = {
                    f: v
                    for f, v in vars(arg).items() if f != 'flags' and v
                }
                sub_proc.add_argument(*arg.flags, **kwargs)
            sub_proc.set_defaults(func=sub['func'])
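The sort_args helper above leaves positional arguments where they are and orders the optional ones by their long flag. A minimal standalone sketch of the same partition-and-sort idea, using a made-up flag table rather than Airflow's real argument registry:

from itertools import filterfalse, tee

def partition(pred, iterable):
    # Split one iterable into (false entries, true entries) without walking it twice.
    iter_1, iter_2 = tee(iterable)
    return filterfalse(pred, iter_1), filter(pred, iter_2)

# Hypothetical flag table: positional args carry no leading dash.
ARG_FLAGS = {
    'dag_id': ('dag_id',),
    'subdir': ('-S', '--subdir'),
    'dry_run': ('-n', '--dry-run'),
}

def long_option(name):
    flags = ARG_FLAGS[name]
    return flags[0] if len(flags) == 1 else flags[1]

positional, optional = partition(
    lambda name: ARG_FLAGS[name][0].startswith('-'), ARG_FLAGS)
ordered = list(positional) + sorted(optional, key=lambda name: long_option(name).lower())
print(ordered)  # ['dag_id', 'dry_run', 'subdir']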
Exemple #15
0
def configure_orm(disable_connection_pool=False):
    log.debug("Setting up DB connection pool (PID %s)" % os.getpid())
    global engine
    global Session
    engine_args = {}

    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.configure_orm(): Using NullPool")
    elif 'sqlite' not in SQL_ALCHEMY_CONN:
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of “sleeping” connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('core',
                                   'SQL_ALCHEMY_MAX_OVERFLOW',
                                   fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('core',
                                   'SQL_ALCHEMY_POOL_RECYCLE',
                                   fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like “SELECT 1”, but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('core',
                                        'SQL_ALCHEMY_POOL_PRE_PING',
                                        fallback=True)

        log.info(
            "settings.configure_orm(): Using pool settings. pool_size={}, max_overflow={}, "
            "pool_recycle={}, pid={}".format(pool_size, max_overflow,
                                             pool_recycle, os.getpid()))
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # Allow the user to specify an encoding for their DB otherwise default
    # to utf-8 so jobs & users with non-latin1 characters can still use
    # us.
    engine_args['encoding'] = conf.get('core',
                                       'SQL_ENGINE_ENCODING',
                                       fallback='utf-8')
    # For Python2 we get back a newstr and need a str
    engine_args['encoding'] = engine_args['encoding'].__str__()

    if conf.has_option('core', 'sql_alchemy_connect_args'):
        connect_args = import_string(
            conf.get('core', 'sql_alchemy_connect_args'))
    else:
        connect_args = {}

    engine = create_engine(SQL_ALCHEMY_CONN,
                           connect_args=connect_args,
                           **engine_args)
    setup_event_handlers(engine)

    Session = scoped_session(
        sessionmaker(autocommit=False,
                     autoflush=False,
                     bind=engine,
                     expire_on_commit=False))
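The pool arguments assembled above map directly onto SQLAlchemy's QueuePool parameters. As a rough sketch, the resulting engine is equivalent to the direct call below; the connection string is a placeholder and the numeric values simply mirror the fallbacks used in the code, standing in for the conf lookups:

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

# Hypothetical DSN; in Airflow this comes from [core] sql_alchemy_conn.
SQL_ALCHEMY_CONN = "postgresql+psycopg2://airflow:airflow@localhost/airflow"

engine = create_engine(
    SQL_ALCHEMY_CONN,
    encoding="utf-8",
    pool_size=5,         # steady-state connections kept open
    max_overflow=10,     # extra connections allowed beyond pool_size
    pool_recycle=1800,   # recycle connections before the server's wait_timeout kills them
    pool_pre_ping=True,  # liveness check (e.g. "SELECT 1") on every checkout
)

Session = scoped_session(
    sessionmaker(autocommit=False, autoflush=False,
                 bind=engine, expire_on_commit=False))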
Exemple #16
0
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.kube_env_vars = configuration_dict.get(
            'kubernetes_environment_variables', {})
        self.env_from_configmap_ref = configuration.get(
            self.kubernetes_section, 'env_from_configmap_ref')
        self.env_from_secret_ref = configuration.get(self.kubernetes_section,
                                                     'env_from_secret_ref')
        self.airflow_home = settings.AIRFLOW_HOME
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section,
                                                'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(self.kubernetes_section,
                                                      'worker_container_tag')
        self.kube_image = '{}:{}'.format(self.worker_container_repository,
                                         self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy")
        self.kube_node_selectors = configuration_dict.get(
            'kubernetes_node_selectors', {})
        self.kube_annotations = configuration_dict.get(
            'kubernetes_annotations', {})
        self.kube_labels = configuration_dict.get('kubernetes_labels', {})
        self.delete_worker_pods = conf.getboolean(self.kubernetes_section,
                                                  'delete_worker_pods')
        self.worker_pods_creation_batch_size = conf.getint(
            self.kubernetes_section, 'worker_pods_creation_batch_size')
        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section,
                                           'image_pull_secrets')

        # NOTE: user can build the dags into the docker image directly,
        # this will set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section,
                                             'dags_in_image')

        # Run as user for pod security context
        self.worker_run_as_user = conf.get(self.kubernetes_section,
                                           'run_as_user')
        self.worker_fs_group = conf.get(self.kubernetes_section, 'fs_group')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')
        # Optionally, the root directory for git operations
        self.git_sync_root = conf.get(self.kubernetes_section, 'git_sync_root')
        # Optionally, the name at which to publish the checked-out files under --root
        self.git_sync_dest = conf.get(self.kubernetes_section, 'git_sync_dest')
        # Optionally, if git_dags_folder_mount_point is set the worker will use
        # {git_dags_folder_mount_point}/{git_sync_dest}/{git_subpath} as dags_folder
        self.git_dags_folder_mount_point = conf.get(
            self.kubernetes_section, 'git_dags_folder_mount_point')

        # Optionally a user may supply a (`git_user` AND `git_password`) OR
        # (`git_ssh_key_secret_name` AND `git_ssh_key_secret_key`) for private repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')
        self.git_ssh_key_secret_name = conf.get(self.kubernetes_section,
                                                'git_ssh_key_secret_name')
        self.git_ssh_known_hosts_configmap_name = conf.get(
            self.kubernetes_section, 'git_ssh_known_hosts_configmap_name')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section,
                                          'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section,
                                          'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(self.kubernetes_section,
                                            'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(self.kubernetes_section,
                                            'logs_volume_subpath')

        # Optionally, hostPath volume containing DAGs
        self.dags_volume_host = conf.get(self.kubernetes_section,
                                         'dags_volume_host')

        # Optionally, write logs to a hostPath Volume
        self.logs_volume_host = conf.get(self.kubernetes_section,
                                         'logs_volume_host')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section,
                                                 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside. Note
        # that if your
        # cluster has RBAC enabled, your scheduler may need service account permissions to
        # create, watch, get, and delete pods in this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor. Note
        # that if your
        # cluster has RBAC enabled, your workers may need service account permissions to
        # interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section,
                                           'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(self.kubernetes_section,
                                               'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section,
                                          'airflow_configmap')

        affinity_json = conf.get(self.kubernetes_section, 'affinity')
        if affinity_json:
            self.kube_affinity = json.loads(affinity_json)
        else:
            self.kube_affinity = None

        tolerations_json = conf.get(self.kubernetes_section, 'tolerations')
        if tolerations_json:
            self.kube_tolerations = json.loads(tolerations_json)
        else:
            self.kube_tolerations = None

        kube_client_request_args = conf.get(self.kubernetes_section,
                                            'kube_client_request_args')
        if kube_client_request_args:
            self.kube_client_request_args = json.loads(
                kube_client_request_args)
            if self.kube_client_request_args['_request_timeout'] and \
                    isinstance(self.kube_client_request_args['_request_timeout'], list):
                self.kube_client_request_args['_request_timeout'] = \
                    tuple(self.kube_client_request_args['_request_timeout'])
        else:
            self.kube_client_request_args = {}
        self._validate()
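Several of the options read above (affinity, tolerations, kube_client_request_args) are stored as JSON strings and decoded here, with _request_timeout additionally coerced to a tuple for the Kubernetes client. A small standalone sketch of that decoding step; the JSON value is a made-up example of what might sit in airflow.cfg:

import json

# Hypothetical [kubernetes] kube_client_request_args value.
raw = '{"_request_timeout": [60, 60]}'

kube_client_request_args = json.loads(raw) if raw else {}
timeout = kube_client_request_args.get('_request_timeout')
if isinstance(timeout, list):
    # The kubernetes client treats a (connect, read) pair as separate timeouts.
    kube_client_request_args['_request_timeout'] = tuple(timeout)

print(kube_client_request_args)  # {'_request_timeout': (60, 60)}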
Exemple #17
0
    def __init__(
            self,
            task_id,  # type: str
            owner=conf.get('operators', 'DEFAULT_OWNER'),  # type: str
            email=None,  # type: Optional[Union[str, Iterable[str]]]
            email_on_retry=True,  # type: bool
            email_on_failure=True,  # type: bool
            retries=conf.getint('core', 'default_task_retries',
                                fallback=0),  # type: int
            retry_delay=timedelta(seconds=300),  # type: timedelta
            retry_exponential_backoff=False,  # type: bool
            max_retry_delay=None,  # type: Optional[datetime]
            start_date=None,  # type: Optional[datetime]
            end_date=None,  # type: Optional[datetime]
            schedule_interval=None,  # not hooked as of now
            depends_on_past=False,  # type: bool
            wait_for_downstream=False,  # type: bool
            dag=None,  # type: Optional[DAG]
            params=None,  # type: Optional[Dict]
            default_args=None,  # type: Optional[Dict]
            priority_weight=1,  # type: int
            weight_rule=WeightRule.DOWNSTREAM,  # type: str
            queue=conf.get('celery', 'default_queue'),  # type: str
            pool=Pool.DEFAULT_POOL_NAME,  # type: str
            sla=None,  # type: Optional[timedelta]
            execution_timeout=None,  # type: Optional[timedelta]
            on_failure_callback=None,  # type: Optional[Callable]
            on_success_callback=None,  # type: Optional[Callable]
            on_retry_callback=None,  # type: Optional[Callable]
            trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
            resources=None,  # type: Optional[Dict]
            run_as_user=None,  # type: Optional[str]
            task_concurrency=None,  # type: Optional[int]
            executor_config=None,  # type: Optional[Dict]
            do_xcom_push=True,  # type: bool
            inlets=None,  # type: Optional[Dict]
            outlets=None,  # type: Optional[Dict]
            *args,
            **kwargs):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        if schedule_interval:
            self.log.warning(
                "schedule_interval is used for %s, though it has "
                "been deprecated as a task parameter, you need to "
                "specify it as a DAG parameter instead", self)
        self._schedule_interval = schedule_interval
        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules,
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(
            **resources) if resources is not None else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        self.subdag = None  # type: Optional[DAG]

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: Iterable[DataSet]
        self.outlets = []  # type: Iterable[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)
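Most of the keyword arguments above are what DAG authors pass when declaring tasks; defaults such as retries and queue fall back to airflow.cfg. A hedged usage sketch, with an illustrative dag_id, dates and command (the BashOperator import path assumes an Airflow 1.10-era layout):

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.bash_operator import BashOperator

with DAG(dag_id="example_retries",
         start_date=datetime(2020, 1, 1),
         schedule_interval="@daily") as dag:
    flaky_step = BashOperator(
        task_id="flaky_step",
        bash_command="exit 0",
        retries=3,                         # overrides core/default_task_retries
        retry_delay=timedelta(minutes=2),  # a bare number of seconds is also accepted
        retry_exponential_backoff=True,
        trigger_rule="all_success",        # must be a valid TriggerRule value
    )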
Exemple #18
0
class DagRun(Base, LoggingMixin):
    """
    DagRun describes an instance of a Dag. It can be created
    by the scheduler (for regular runs) or by an external trigger
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(String(ID_LEN))
    execution_date = Column(UtcDateTime, default=timezone.utcnow)
    start_date = Column(UtcDateTime, default=timezone.utcnow)
    end_date = Column(UtcDateTime)
    _state = Column('state', String(50), default=State.RUNNING)
    run_id = Column(String(ID_LEN))
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    conf = Column(PickleType)
    # When a scheduler last attempted to schedule TIs for this DagRun
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))

    dag = None

    __table_args__ = (
        Index('dag_id_state', dag_id, _state),
        UniqueConstraint('dag_id', 'execution_date'),
        UniqueConstraint('dag_id', 'run_id'),
        Index('idx_last_scheduling_decision', last_scheduling_decision),
    )

    task_instances = relationship(
        TI,
        primaryjoin=and_(TI.dag_id == dag_id, TI.execution_date == execution_date),  # type: ignore
        foreign_keys=(dag_id, execution_date),
        backref=backref('dag_run', uselist=False),
    )

    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        'scheduler',
        'max_dagruns_per_loop_to_schedule',
        fallback=20,
    )

    def __init__(
        self,
        dag_id: Optional[str] = None,
        run_id: Optional[str] = None,
        execution_date: Optional[datetime] = None,
        start_date: Optional[datetime] = None,
        external_trigger: Optional[bool] = None,
        conf: Optional[Any] = None,
        state: Optional[str] = None,
        run_type: Optional[str] = None,
        dag_hash: Optional[str] = None,
        creating_job_id: Optional[int] = None,
    ):
        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        self.state = state
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        super().__init__()

    def __repr__(self):
        return (
            '<DagRun {dag_id} @ {execution_date}: {run_id}, externally triggered: {external_trigger}>'
        ).format(
            dag_id=self.dag_id,
            execution_date=self.execution_date,
            run_id=self.run_id,
            external_trigger=self.external_trigger,
        )

    def get_state(self):
        return self._state

    def set_state(self, state):
        if self._state != state:
            self._state = state
            self.end_date = timezone.utcnow() if self._state in State.finished else None

    @declared_attr
    def state(self):
        return synonym('_state', descriptor=property(self.get_state, self.set_state))

    @provide_session
    def refresh_from_db(self, session: Session = None):
        """
        Reloads the current dagrun from the database

        :param session: database session
        :type session: Session
        """
        DR = DagRun

        exec_date = func.cast(self.execution_date, DateTime)

        dr = (
            session.query(DR)
            .filter(
                DR.dag_id == self.dag_id,
                func.cast(DR.execution_date, DateTime) == exec_date,
                DR.run_id == self.run_id,
            )
            .one()
        )

        self.id = dr.id
        self.state = dr.state

    @classmethod
    def next_dagruns_to_examine(
        cls,
        session: Session,
        max_number: Optional[int] = None,
    ):
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
        query; you should ensure that any scheduling decisions are made in a single transaction -- as soon as
        the transaction is committed it will be unlocked.

        :rtype: list[airflow.models.DagRun]
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (
            session.query(cls)
            .filter(cls.state == State.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB)
            .join(
                DagModel,
                DagModel.dag_id == cls.dag_id,
            )
            .filter(
                DagModel.is_paused.is_(False),
                DagModel.is_active.is_(True),
            )
            .order_by(
                nulls_first(cls.last_scheduling_decision, session=session),
                cls.execution_date,
            )
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.filter(DagRun.execution_date <= func.now())

        return with_row_locks(query.limit(max_number), of=cls, **skip_locked(session=session))

    @staticmethod
    @provide_session
    def find(
        dag_id: Optional[Union[str, List[str]]] = None,
        run_id: Optional[str] = None,
        execution_date: Optional[datetime] = None,
        state: Optional[str] = None,
        external_trigger: Optional[bool] = None,
        no_backfills: bool = False,
        run_type: Optional[DagRunType] = None,
        session: Session = None,
        execution_start_date: Optional[datetime] = None,
        execution_end_date: Optional[datetime] = None,
    ) -> List["DagRun"]:
        """
        Returns a set of dag runs for the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :type dag_id: str or list[str]
        :param run_id: defines the run id for this dag run
        :type run_id: str
        :param run_type: type of DagRun
        :type run_type: airflow.utils.types.DagRunType
        :param execution_date: the execution date
        :type execution_date: datetime.datetime or list[datetime.datetime]
        :param state: the state of the dag run
        :type state: str
        :param external_trigger: whether this dag run is externally triggered
        :type external_trigger: bool
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :type no_backfills: bool
        :param session: database session
        :type session: sqlalchemy.orm.session.Session
        :param execution_start_date: dag run that was executed from this date
        :type execution_start_date: datetime.datetime
        :param execution_end_date: dag run that was executed until this date
        :type execution_end_date: datetime.datetime
        """
        DR = DagRun

        qry = session.query(DR)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.filter(DR.dag_id.in_(dag_ids))
        if run_id:
            qry = qry.filter(DR.run_id == run_id)
        if execution_date:
            if isinstance(execution_date, list):
                qry = qry.filter(DR.execution_date.in_(execution_date))
            else:
                qry = qry.filter(DR.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.filter(DR.execution_date.between(execution_start_date, execution_end_date))
        elif execution_start_date:
            qry = qry.filter(DR.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.filter(DR.execution_date <= execution_end_date)
        if state:
            qry = qry.filter(DR.state == state)
        if external_trigger is not None:
            qry = qry.filter(DR.external_trigger == external_trigger)
        if run_type:
            qry = qry.filter(DR.run_type == run_type)
        if no_backfills:
            qry = qry.filter(DR.run_type != DagRunType.BACKFILL_JOB)

        dr = qry.order_by(DR.execution_date).all()

        return dr

    @staticmethod
    def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
        """Generate Run ID based on Run Type and Execution Date"""
        return f"{run_type}__{execution_date.isoformat()}"

    @provide_session
    def get_task_instances(self, state=None, session=None):
        """Returns the task instances for this dag run"""
        tis = session.query(TI).filter(
            TI.dag_id == self.dag_id,
            TI.execution_date == self.execution_date,
        )

        if state:
            if isinstance(state, str):
                tis = tis.filter(TI.state == state)
            else:
                # this is required to deal with NULL values
                if None in state:
                    if all(x is None for x in state):
                        tis = tis.filter(TI.state.is_(None))
                    else:
                        not_none_state = [s for s in state if s]
                        tis = tis.filter(or_(TI.state.in_(not_none_state), TI.state.is_(None)))
                else:
                    tis = tis.filter(TI.state.in_(state))

        if self.dag and self.dag.partial:
            tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
        return tis.all()

    @provide_session
    def get_task_instance(self, task_id: str, session: Session = None):
        """
        Returns the task instance specified by task_id for this dag run

        :param task_id: the task id
        :type task_id: str
        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        ti = (
            session.query(TI)
            .filter(TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id)
            .first()
        )

        return ti

    def get_dag(self):
        """
        Returns the Dag associated with this DagRun.

        :return: DAG
        """
        if not self.dag:
            raise AirflowException(f"The DAG (.dag) for {self} needs to be set")

        return self.dag

    @provide_session
    def get_previous_dagrun(self, state: Optional[str] = None, session: Session = None) -> Optional['DagRun']:
        """The previous DagRun, if there is one"""
        filters = [
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date < self.execution_date,
        ]
        if state is not None:
            filters.append(DagRun.state == state)
        return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()

    @provide_session
    def get_previous_scheduled_dagrun(self, session: Session = None):
        """The previous, SCHEDULED DagRun, if there is one"""
        dag = self.get_dag()

        return (
            session.query(DagRun)
            .filter(
                DagRun.dag_id == self.dag_id,
                DagRun.execution_date == dag.previous_schedule(self.execution_date),
            )
            .first()
        )

    @provide_session
    def update_state(
        self, session: Session = None, execute_callbacks: bool = True
    ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
            directly (default: true) or recorded as a pending request in the ``callback`` property
        :type execute_callbacks: bool
        :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: Optional[callback_requests.DagCallbackRequest] = None

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm

        dag = self.get_dag()
        info = self.task_instance_scheduling_decisions(session)

        tis = info.tis
        schedulable_tis = info.schedulable_tis
        changed_tis = info.changed_tis
        finished_tasks = info.finished_tasks
        unfinished_tasks = info.unfinished_tasks

        none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
        none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)

        if unfinished_tasks and none_depends_on_past and none_task_concurrency:
            # small speed up
            are_runnable_tasks = (
                schedulable_tis
                or self._are_premature_tis(unfinished_tasks, finished_tasks, session)
                or changed_tis
            )

        duration = timezone.utcnow() - start_dttm
        Stats.timing(f"dagrun.dependency-check.{self.dag_id}", duration)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all roots finished and at least one failed, the run failed
        if not unfinished_tasks and any(
            leaf_ti.state in {State.FAILED, State.UPSTREAM_FAILED} for leaf_ti in leaf_tis
        ):
            self.log.error('Marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self, success=False, reason='task_failure', session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='task_failure',
                )

        # if all leafs succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tasks and all(
            leaf_ti.state in {State.SUCCESS, State.SKIPPED} for leaf_ti in leaf_tis
        ):
            self.log.info('Marking run %s successful', self)
            self.set_state(State.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self, success=True, reason='success', session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=False,
                    msg='success',
                )

        # if *all tasks* are deadlocked, the run failed
        elif unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks:
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='all_tasks_deadlocked',
                )

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(State.RUNNING)

        self._emit_duration_stats_for_finished_state()

        session.merge(self)

        return schedulable_tis, callback

    @provide_session
    def task_instance_scheduling_decisions(self, session: Session = None) -> TISchedulingDecision:

        schedulable_tis: List[TI] = []
        changed_tis = False

        tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
        for ti in tis:
            ti.task = self.get_dag().get_task(ti.task_id)

        unfinished_tasks = [t for t in tis if t.state in State.unfinished]
        finished_tasks = [t for t in tis if t.state in State.finished]
        if unfinished_tasks:
            scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
            self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks))
            schedulable_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)

        return TISchedulingDecision(
            tis=tis,
            schedulable_tis=schedulable_tis,
            changed_tis=changed_tis,
            unfinished_tasks=unfinished_tasks,
            finished_tasks=finished_tasks,
        )

    def _get_ready_tis(
        self,
        scheduleable_tasks: List[TI],
        finished_tasks: List[TI],
        session: Session,
    ) -> Tuple[List[TI], bool]:
        old_states = {}
        ready_tis: List[TI] = []
        changed_tis = False

        if not scheduleable_tasks:
            return ready_tis, changed_tis

        # Check dependencies
        for st in scheduleable_tasks:
            old_state = st.state
            if st.are_dependencies_met(
                dep_context=DepContext(flag_upstream_failed=True, finished_tasks=finished_tasks),
                session=session,
            ):
                ready_tis.append(st)
            else:
                old_states[st.key] = old_state

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states.keys())
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key] for ti in fresh_tis)

        return ready_tis, changed_tis

    def _are_premature_tis(
        self,
        unfinished_tasks: List[TI],
        finished_tasks: List[TI],
        session: Session,
    ) -> bool:
        # there might be runnable tasks that are up for retry and for some reason(retry delay, etc) are
        # not ready yet so we set the flags to count them in
        for ut in unfinished_tasks:
            if ut.are_dependencies_met(
                dep_context=DepContext(
                    flag_upstream_failed=True,
                    ignore_in_retry_period=True,
                    ignore_in_reschedule_period=True,
                    finished_tasks=finished_tasks,
                ),
                session=session,
            ):
                return True
        return False

    def _emit_duration_stats_for_finished_state(self):
        if self.state == State.RUNNING:
            return

        duration = self.end_date - self.start_date
        if self.state is State.SUCCESS:
            Stats.timing(f'dagrun.duration.success.{self.dag_id}', duration)
        elif self.state == State.FAILED:
            Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)

    @provide_session
    def verify_integrity(self, session: Session = None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state is not State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
                ti = TI(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.flush()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info(
                'Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.'
            )
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()

    @staticmethod
    def get_run(session: Session, dag_id: str, execution_date: datetime):
        """
        Get a single DAG Run

        :param session: Sqlalchemy ORM Session
        :type session: Session
        :param dag_id: DAG ID
        :type dag_id: unicode
        :param execution_date: execution date
        :type execution_date: datetime
        :return: DagRun corresponding to the given dag_id and execution date
            if one exists. None otherwise.
        :rtype: airflow.models.DagRun
        """
        qry = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.external_trigger == False,  # noqa pylint: disable=singleton-comparison
            DagRun.execution_date == execution_date,
        )
        return qry.first()

    @property
    def is_backfill(self):
        return self.run_type == DagRunType.BACKFILL_JOB

    @classmethod
    @provide_session
    def get_latest_runs(cls, session=None):
        """Returns the latest DagRun for each DAG"""
        subquery = (
            session.query(cls.dag_id, func.max(cls.execution_date).label('execution_date'))
            .group_by(cls.dag_id)
            .subquery()
        )
        dagruns = (
            session.query(cls)
            .join(
                subquery,
                and_(cls.dag_id == subquery.c.dag_id, cls.execution_date == subquery.c.execution_date),
            )
            .all()
        )
        return dagruns

    @provide_session
    def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -> int:
        """
        Set the given task instances in to the scheduled state.

        Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

        Any DummyOperator without callbacks is instead set straight to the success state.

        All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it
        is the caller's responsibility to call this function only with TIs from a single dag run.
        """
        # Get the list of TIs that do not need to be executed: tasks using
        # DummyOperator without on_execute_callback / on_success_callback
        dummy_tis = {
            ti
            for ti in schedulable_tis
            if (
                ti.task.task_type == "DummyOperator"
                and not ti.task.on_execute_callback
                and not ti.task.on_success_callback
            )
        }

        schedulable_ti_ids = [ti.task_id for ti in schedulable_tis if ti not in dummy_tis]
        count = 0

        if schedulable_ti_ids:
            count += (
                session.query(TI)
                .filter(
                    TI.dag_id == self.dag_id,
                    TI.execution_date == self.execution_date,
                    TI.task_id.in_(schedulable_ti_ids),
                )
                .update({TI.state: State.SCHEDULED}, synchronize_session=False)
            )

        # Tasks using DummyOperator should not be executed, mark them as success
        if dummy_tis:
            count += (
                session.query(TI)
                .filter(
                    TI.dag_id == self.dag_id,
                    TI.execution_date == self.execution_date,
                    TI.task_id.in_(ti.task_id for ti in dummy_tis),
                )
                .update(
                    {
                        TI.state: State.SUCCESS,
                        TI.start_date: timezone.utcnow(),
                        TI.end_date: timezone.utcnow(),
                        TI.duration: 0,
                    },
                    synchronize_session=False,
                )
            )

        return count
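DagRun.find above is the general-purpose query helper used by the scheduler, the API and the CLI. A hedged usage sketch; the dag_id and dates are placeholders, and an initialised Airflow metadata database is assumed:

from datetime import datetime

from airflow.models import DagRun
from airflow.utils.state import State
from airflow.utils.types import DagRunType

# All running, non-backfill runs of a hypothetical DAG executed in January 2020.
runs = DagRun.find(
    dag_id="example_dag",
    state=State.RUNNING,
    no_backfills=True,
    execution_start_date=datetime(2020, 1, 1),
    execution_end_date=datetime(2020, 2, 1),
)
for run in runs:
    print(run.run_id, run.state)

# Run IDs combine the run type and the execution date, giving something like
# "scheduled__2020-01-01T00:00:00".
print(DagRun.generate_run_id(DagRunType.SCHEDULED, datetime(2020, 1, 1)))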
Exemple #19
0
    def __init__(self,
                 shard_max=conf.getint('smart_sensor',
                                       'shard_code_upper_limit'),
                 shard_min=0,
                 **kwargs):
        super().__init__(shard_min=shard_min, shard_max=shard_max, **kwargs)
Exemple #20
0
def create_app(config=None, testing=False, app_name="Airflow"):
    global app, appbuilder
    app = Flask(__name__)
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    session_lifetime_days = conf.getint('webserver', 'SESSION_LIFETIME_DAYS', fallback=30)
    app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(days=session_lifetime_days)

    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config['APP_NAME'] = app_name
    app.config['TESTING'] = testing
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    app.config['SESSION_COOKIE_HTTPONLY'] = True
    app.config['SESSION_COOKIE_SECURE'] = conf.getboolean('webserver', 'COOKIE_SECURE')
    app.config['SESSION_COOKIE_SAMESITE'] = conf.get('webserver', 'COOKIE_SAMESITE')

    if config:
        app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    app.json_encoder = AirflowJsonEncoder

    csrf.init_app(app)
    db = SQLA()
    db.session = settings.Session
    db.init_app(app)

    from airflow import api
    api.load_auth()
    api.API_AUTH.api_auth.init_app(app)

    Cache(app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})

    from airflow.www.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():
        from airflow.www.security import AirflowSecurityManager
        security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \
            AirflowSecurityManager

        if not issubclass(security_manager_class, AirflowSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager,
                 not FAB's security manager.""")

        class AirflowAppBuilder(AppBuilder):

            def _check_and_init(self, baseview):
                if hasattr(baseview, 'datamodel'):
                    # Delete sessions if initiated previously to limit side effects. We want to use
                    # the current session in the current application.
                    baseview.datamodel.session = None
                return super()._check_and_init(baseview)

        appbuilder = AirflowAppBuilder(
            app=app,
            session=settings.Session,
            security_manager_class=security_manager_class,
            base_template='airflow/master.html',
            update_perms=conf.getboolean('webserver', 'UPDATE_FAB_PERMS'))

        def init_views(appbuilder):
            from airflow.www import views
            # Remove the session from scoped_session registry to avoid
            # reusing a session with a disconnected connection
            appbuilder.session.remove()
            appbuilder.add_view_no_menu(views.Airflow())
            appbuilder.add_view_no_menu(views.DagModelView())
            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            appbuilder.add_view(views.JobModelView,
                                "Jobs",
                                category="Browse")
            appbuilder.add_view(views.LogModelView,
                                "Logs",
                                category="Browse")
            appbuilder.add_view(views.SlaMissModelView,
                                "SLA Misses",
                                category="Browse")
            appbuilder.add_view(views.TaskInstanceModelView,
                                "Task Instances",
                                category="Browse")
            appbuilder.add_view(views.ConfigurationView,
                                "Configurations",
                                category="Admin",
                                category_icon="fa-user")
            appbuilder.add_view(views.ConnectionModelView,
                                "Connections",
                                category="Admin")
            appbuilder.add_view(views.PoolModelView,
                                "Pools",
                                category="Admin")
            appbuilder.add_view(views.VariableModelView,
                                "Variables",
                                category="Admin")
            appbuilder.add_view(views.XComModelView,
                                "XComs",
                                category="Admin")

            if "dev" in version.version:
                airflow_doc_site = "https://airflow.readthedocs.io/en/latest"
            else:
                airflow_doc_site = 'https://airflow.apache.org/docs/{}'.format(version.version)

            appbuilder.add_link("Website",
                                href='https://airflow.apache.org',
                                category="Docs",
                                category_icon="fa-globe")
            appbuilder.add_link("Documentation",
                                href=airflow_doc_site,
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_view(views.VersionView,
                                'Version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow import plugins_manager

                plugins_manager.initialize_web_ui_plugins()

                for v in plugins_manager.flask_appbuilder_views:
                    log.debug("Adding view %s", v["name"])
                    appbuilder.add_view(v["view"],
                                        v["name"],
                                        category=v["category"])
                for ml in sorted(plugins_manager.flask_appbuilder_menu_links, key=lambda x: x["name"]):
                    log.debug("Adding menu link %s", ml["name"])
                    appbuilder.add_link(ml["name"],
                                        href=ml["href"],
                                        category=ml["category"],
                                        category_icon=ml["category_icon"])

            integrate_plugins()
            # Garbage collect old permissions/views after they have been modified.
            # Otherwise, when the name of a view or menu is changed, the framework
            # will add the new Views and Menus names to the backend, but will not
            # delete the old ones.

        def init_plugin_blueprints(app):
            from airflow.plugins_manager import flask_blueprints

            for bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s", bp["name"], bp["blueprint"].import_name)
                app.register_blueprint(bp["blueprint"])

        def init_error_handlers(app: Flask):
            from airflow.www import views
            app.register_error_handler(500, views.show_traceback)
            app.register_error_handler(404, views.circles)

        init_views(appbuilder)
        init_plugin_blueprints(app)
        init_error_handlers(app)

        if conf.getboolean('webserver', 'UPDATE_FAB_PERMS'):
            security_manager = appbuilder.sm
            security_manager.sync_roles()

        from airflow.www.api.experimental import endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            import importlib
            importlib.reload(e)

        app.register_blueprint(e.api_experimental, url_prefix='/api/experimental')

        server_timezone = conf.get('core', 'default_timezone')
        if server_timezone == "system":
            server_timezone = pendulum.local_timezone().name
        elif server_timezone == "utc":
            server_timezone = "UTC"

        default_ui_timezone = conf.get('webserver', 'default_ui_timezone')
        if default_ui_timezone == "system":
            default_ui_timezone = pendulum.local_timezone().name
        elif default_ui_timezone == "utc":
            default_ui_timezone = "UTC"
        if not default_ui_timezone:
            default_ui_timezone = server_timezone

        @app.context_processor
        def jinja_globals():  # pylint: disable=unused-variable

            globals = {
                'server_timezone': server_timezone,
                'default_ui_timezone': default_ui_timezone,
                'hostname': socket.getfqdn() if conf.getboolean(
                    'webserver', 'EXPOSE_HOSTNAME', fallback=True) else 'redact',
                'navbar_color': conf.get(
                    'webserver', 'NAVBAR_COLOR'),
                'log_fetch_delay_sec': conf.getint(
                    'webserver', 'log_fetch_delay_sec', fallback=2),
                'log_auto_tailing_offset': conf.getint(
                    'webserver', 'log_auto_tailing_offset', fallback=30),
                'log_animation_speed': conf.getint(
                    'webserver', 'log_animation_speed', fallback=1000)
            }

            if 'analytics_tool' in conf.getsection('webserver'):
                globals.update({
                    'analytics_tool': conf.get('webserver', 'ANALYTICS_TOOL'),
                    'analytics_id': conf.get('webserver', 'ANALYTICS_ID')
                })

            return globals

        @app.before_request
        def before_request():
            _force_log_out_after = conf.getint('webserver', 'FORCE_LOG_OUT_AFTER', fallback=0)
            if _force_log_out_after > 0:
                flask.session.permanent = True
                app.permanent_session_lifetime = datetime.timedelta(minutes=_force_log_out_after)
                flask.session.modified = True
                flask.g.user = flask_login.current_user

        @app.after_request
        def apply_caching(response):
            _x_frame_enabled = conf.getboolean('webserver', 'X_FRAME_ENABLED', fallback=True)
            if not _x_frame_enabled:
                response.headers["X-Frame-Options"] = "DENY"
            return response

        @app.before_request
        def make_session_permanent():
            flask_session.permanent = True

    return app, appbuilder
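
# The example above resolves "system"/"utc" timezone names twice (once for the server,
# once for the UI default). A minimal standalone sketch of that resolution, assuming only
# that pendulum is available; the helper name below is made up and is not part of the
# Airflow code shown here.
import pendulum


def resolve_timezone_name(configured, fallback="UTC"):
    """Map a configured timezone value to a concrete timezone name."""
    if configured == "system":
        return pendulum.local_timezone().name
    if configured == "utc":
        return "UTC"
    return configured or fallback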
Exemple #21
0
    # Ensure we close DB connections at scheduler and gunicorn worker terminations
    atexit.register(dispose_orm)


# pylint: enable=global-statement

# Const stuff

KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'}

# A serialized DAG cannot be updated faster than this minimum interval, to limit the
# database write rate.
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
    'core', 'min_serialized_dag_update_interval', fallback=30)

# A serialized DAG cannot be fetched faster than this minimum interval, to limit the
# database read rate. This config controls how quickly updated DAGs appear in the Webserver.
MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint(
    'core', 'min_serialized_dag_fetch_interval', fallback=10)

# Whether to persist DAG files code in DB. If set to True, Webserver reads file contents
# from DB instead of trying to access files in a DAG folder.
STORE_DAG_CODE = conf.getboolean("core", "store_dag_code", fallback=True)

# If donot_modify_handlers=True, we do not modify logging handlers in the task_run command.
# If the flag is set to False, we remove all handlers from the root logger and attach the
# handlers from the 'airflow.task' logger to the root logger instead. This is done to
# capture all output from the print & log statements in DAG files before a task is run.
# The handlers are restored after the task completes execution.
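
# A minimal sketch of the handler swap described in the comment above (for the case where
# donot_modify_handlers is False): the root handlers are saved, the 'airflow.task' handlers
# temporarily take their place so output from DAG files is captured, and the originals are
# restored afterwards. This is an illustration, not the Airflow task_run implementation.
import logging

root_logger = logging.getLogger()
task_logger = logging.getLogger('airflow.task')

saved_handlers = root_logger.handlers[:]      # keep the originals for restoration
root_logger.handlers[:] = task_logger.handlers
try:
    pass  # the task would run here
finally:
    root_logger.handlers[:] = saved_handlers  # restore after execution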
Exemple #22
0
    def _read(self, ti, try_number, metadata=None):
        """
        Template method that contains custom logic of reading
        logs given the try_number.

        :param ti: task instance record
        :param try_number: current try_number to read log from
        :param metadata: log metadata,
                         can be used for streaming log reading and auto-tailing.
        :return: log message as a string and metadata.
        """
        # Task instance here might be different from task instance when
        # initializing the handler. Thus explicitly getting log location
        # is needed to get correct log path.
        log_relative_path = self._render_filename(ti, try_number)
        location = os.path.join(self.local_base, log_relative_path)

        log = ""

        if os.path.exists(location):
            try:
                with open(location, encoding="utf-8",
                          errors="surrogateescape") as file:
                    log += f"*** Reading local file: {location}\n"
                    log += "".join(file.readlines())
            except Exception as e:
                log = f"*** Failed to load local log file: {location}\n"
                log += f"*** {str(e)}\n"
        elif conf.get('core', 'executor') == 'KubernetesExecutor':
            try:
                from airflow.kubernetes.kube_client import get_kube_client

                kube_client = get_kube_client()

                if len(ti.hostname) >= 63:
                    # Kubernetes takes the pod name and truncates it for the hostname. This truncated hostname
                    # is returned for the fqdn to comply with the 63 character limit imposed by DNS standards
                    # on any label of a FQDN.
                    pod_list = kube_client.list_namespaced_pod(
                        conf.get('kubernetes', 'namespace'))
                    matches = [
                        pod.metadata.name for pod in pod_list.items
                        if pod.metadata.name.startswith(ti.hostname)
                    ]
                    if len(matches) == 1:
                        if len(matches[0]) > len(ti.hostname):
                            ti.hostname = matches[0]

                log += f'*** Trying to get logs (last 100 lines) from worker pod {ti.hostname} ***\n\n'

                res = kube_client.read_namespaced_pod_log(
                    name=ti.hostname,
                    namespace=conf.get('kubernetes', 'namespace'),
                    container='base',
                    follow=False,
                    tail_lines=100,
                    _preload_content=False,
                )

                for line in res:
                    log += line.decode()

            except Exception as f:
                log += f'*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n'
        else:
            import httpx

            url = os.path.join(
                "http://{ti.hostname}:{worker_log_server_port}/log",
                log_relative_path).format(ti=ti,
                                          worker_log_server_port=conf.get(
                                              'logging',
                                              'WORKER_LOG_SERVER_PORT'))
            log += f"*** Log file does not exist: {location}\n"
            log += f"*** Fetching from: {url}\n"
            try:
                timeout = None  # No timeout
                try:
                    timeout = conf.getint('webserver', 'log_fetch_timeout_sec')
                except (AirflowConfigException, ValueError):
                    pass

                signer = TimedJSONWebSignatureSerializer(
                    secret_key=conf.get('webserver', 'secret_key'),
                    algorithm_name='HS512',
                    expires_in=conf.getint('webserver',
                                           'log_request_clock_grace',
                                           fallback=30),
                    # This isn't really a "salt", more of a signing context
                    salt='task-instance-logs',
                )

                response = httpx.get(
                    url,
                    timeout=timeout,
                    headers={'Authorization': signer.dumps(log_relative_path)})
                response.encoding = "utf-8"

                if response.status_code == 403:
                    log += (
                        "*** !!!! Please make sure that all your Airflow components (e.g. "
                        "schedulers, webservers and workers) have "
                        "the same 'secret_key' configured in 'webserver' section and "
                        "time is synchronized on all your machines (for example with ntpd) !!!!!\n***"
                    )
                    log += (
                        "*** See more at https://airflow.apache.org/docs/apache-airflow/"
                        "stable/configurations-ref.html#secret-key\n***")
                # Check if the resource was properly fetched
                response.raise_for_status()

                log += '\n' + response.text
            except Exception as e:
                log += f"*** Failed to fetch log file from worker. {str(e)}\n"

        return log, {'end_of_log': True}
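
# A standalone round-trip sketch of the signed Authorization header used above: the log
# reader signs the relative log path and the worker verifies it with the same secret_key.
# The values here are made up; this requires itsdangerous < 2.1, where
# TimedJSONWebSignatureSerializer is still available.
from itsdangerous import TimedJSONWebSignatureSerializer

signer = TimedJSONWebSignatureSerializer(
    secret_key='shared-secret',   # must match across webservers and workers
    algorithm_name='HS512',
    expires_in=30,                # clock grace, in seconds
    salt='task-instance-logs',
)
token = signer.dumps('dag_id/task_id/2021-01-01T00:00:00/1.log')
assert signer.loads(token) == 'dag_id/task_id/2021-01-01T00:00:00/1.log'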
Exemple #23
0
from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.timezone import utcnow

log = logging.getLogger(__name__)

# Make it constant for unit test.
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'

OPERATION_TIMEOUT = conf.getint('celery', 'operation_timeout', fallback=2)
'''
To start the celery worker, run the command:
airflow celery worker
'''

if conf.has_option('celery', 'celery_config_options'):
    celery_configuration = conf.getimport('celery', 'celery_config_options')
else:
    celery_configuration = DEFAULT_CELERY_CONFIG

app = Celery(conf.get('celery', 'CELERY_APP_NAME'),
             config_source=celery_configuration)
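
# Hedged sketch of a user-supplied Celery config (the dict name and override below are
# hypothetical): `celery_config_options` in the [celery] section can point at an importable
# dict like this one, which starts from DEFAULT_CELERY_CONFIG and overrides selected keys
# instead of redefining everything.
CUSTOM_CELERY_CONFIG = {
    **DEFAULT_CELERY_CONFIG,
    'worker_prefetch_multiplier': 1,  # example override; the value is illustrative
}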


@app.task
Exemple #24
0
    def __init__(
        self,
        task_id: str,
        owner: str = conf.get('operators', 'DEFAULT_OWNER'),
        email: Optional[Union[str, Iterable[str]]] = None,
        email_on_retry: bool = True,
        email_on_failure: bool = True,
        retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
        retry_delay: timedelta = timedelta(seconds=300),
        retry_exponential_backoff: bool = False,
        max_retry_delay: Optional[timedelta] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        depends_on_past: bool = False,
        wait_for_downstream: bool = False,
        dag=None,
        params: Optional[Dict] = None,
        default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
        priority_weight: int = 1,
        weight_rule: str = WeightRule.DOWNSTREAM,
        queue: str = conf.get('celery', 'default_queue'),
        pool: str = Pool.DEFAULT_POOL_NAME,
        sla: Optional[timedelta] = None,
        execution_timeout: Optional[timedelta] = None,
        on_execute_callback: Optional[Callable] = None,
        on_failure_callback: Optional[Callable] = None,
        on_success_callback: Optional[Callable] = None,
        on_retry_callback: Optional[Callable] = None,
        trigger_rule: str = TriggerRule.ALL_SUCCESS,
        resources: Optional[Dict] = None,
        run_as_user: Optional[str] = None,
        task_concurrency: Optional[int] = None,
        executor_config: Optional[Dict] = None,
        do_xcom_push: bool = True,
        inlets: Optional[Any] = None,
        outlets: Optional[Any] = None,
        *args,
        **kwargs
    ):
        from airflow.models.dag import DagContext
        super().__init__()
        if args or kwargs:
            if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                raise AirflowException(
                    "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                    "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                        c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                )
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'future. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3
            )
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_triggers=TriggerRule.all_triggers(),
                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_execute_callback = on_execute_callback
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            # noinspection PyTypeChecker
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_weight_rules=WeightRule.all_weight_rules(),
                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
        self.weight_rule = weight_rule
        self.resources: Optional[Resources] = Resources(**resources) if resources else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()
        self._dag = None

        self.dag = dag or DagContext.get_current_dag()

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        from airflow.models.dag import DAG
        self.subdag: Optional[DAG] = None

        self._log = logging.getLogger("airflow.task.operators")

        # Lineage
        self.inlets: List = []
        self.outlets: List = []

        self._inlets: List = []
        self._outlets: List = []

        if inlets:
            self._inlets = inlets if isinstance(inlets, list) else [inlets, ]

        if outlets:
            self._outlets = outlets if isinstance(outlets, list) else [outlets, ]
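
# A tiny standalone illustration (a local function, not Airflow API) of two normalizations
# performed in the __init__ above: wait_for_downstream implies depends_on_past, and a plain
# number passed as retry_delay is coerced to a timedelta in seconds.
from datetime import timedelta


def normalize_retry_args(retry_delay, depends_on_past=False, wait_for_downstream=False):
    if wait_for_downstream:
        depends_on_past = True
    if not isinstance(retry_delay, timedelta):
        retry_delay = timedelta(seconds=retry_delay)
    return retry_delay, depends_on_past


assert normalize_retry_args(300, wait_for_downstream=True) == (timedelta(seconds=300), True)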
Exemple #25
0
from airflow import executors, models, settings, utils
from airflow.configuration import conf
from airflow.utils import AirflowException, State


Base = models.Base
ID_LEN = models.ID_LEN

# Setting up a statsd client if needed
statsd = None
if conf.getboolean('scheduler', 'statsd_on'):
    from statsd import StatsClient
    statsd = StatsClient(
        host=conf.get('scheduler', 'statsd_host'),
        port=conf.getint('scheduler', 'statsd_port'),
        prefix=conf.get('scheduler', 'statsd_prefix'))


class BaseJob(Base):
    """
    Abstract class to be derived for jobs. Jobs are processing items with state
    and duration that aren't task instances. For instance a BackfillJob is
    a collection of task instance runs, but should have its own state, start
    and end time.
    """

    __tablename__ = "job"

    id = Column(Integer, primary_key=True)
    dag_id = Column(String(ID_LEN),)
Exemple #26
0
    def gauge(self, stat, value, rate=1, delta=False):
        if self.allow_list_validator.test(stat):
            return self.statsd.gauge(stat, value, rate, delta)

    def timing(self, stat, dt):
        if self.allow_list_validator.test(stat):
            return self.statsd.timing(stat, dt)


Stats = DummyStatsLogger  # type: Any

if conf.getboolean('scheduler', 'statsd_on'):
    from statsd import StatsClient

    statsd = StatsClient(host=conf.get('scheduler', 'statsd_host'),
                         port=conf.getint('scheduler', 'statsd_port'),
                         prefix=conf.get('scheduler', 'statsd_prefix'))

    allow_list_validator = AllowListValidator(
        conf.get('scheduler', 'statsd_allow_list', fallback=None))

    Stats = SafeStatsdLogger(statsd, allow_list_validator)
else:
    Stats = DummyStatsLogger

HEADER = '\n'.join([
    r'  ____________       _____________',
    r' ____    |__( )_________  __/__  /________      __',
    r'____  /| |_  /__  ___/_  /_ __  /_  __ \_ | /| / /',
    r'___  ___ |  / _  /   _  __/ _  / / /_/ /_ |/ |/ /',
    r' _/_/  |_/_/  /_/    /_/    /_/  \____/____/|__/',
Exemple #27
0
import re
import signal
import subprocess
from datetime import datetime
from functools import reduce

import psutil
from jinja2 import Template

from airflow.configuration import conf
from airflow.exceptions import AirflowException

# When killing processes, time to wait after issuing a SIGTERM before issuing a
# SIGKILL.
DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint(
    'core', 'KILLED_TASK_CLEANUP_TIME'
)
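
# A minimal sketch of the pattern this constant supports (not the Airflow helper itself):
# send SIGTERM, give the process DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM seconds to exit
# cleanly, then fall back to SIGKILL.
def terminate_then_kill(pid, wait_seconds=DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM):
    proc = psutil.Process(pid)
    proc.terminate()                     # SIGTERM
    try:
        proc.wait(timeout=wait_seconds)  # allow a clean shutdown
    except psutil.TimeoutExpired:
        proc.kill()                      # SIGKILL as a last resort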

KEY_REGEX = re.compile(r'^[\w.-]+$')


def validate_key(k, max_length=250):
    """
    Validates value used as a key.
    """
    if not isinstance(k, str):
        raise TypeError("The key has to be a string")
    elif len(k) > max_length:
        raise AirflowException(
            "The key has to be less than {0} characters".format(max_length))
    elif not KEY_REGEX.match(k):
Exemple #28
0
log = logging.getLogger(__name__)

broker_url = conf.get('celery', 'BROKER_URL')

broker_transport_options = conf.getsection(
    'celery_broker_transport_options') or {}
if 'visibility_timeout' not in broker_transport_options:
    if _broker_supports_visibility_timeout(broker_url):
        broker_transport_options['visibility_timeout'] = 21600

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'),
    'task_acks_late': True,
    'task_default_queue': conf.get('operators', 'DEFAULT_QUEUE'),
    'task_default_exchange': conf.get('operators', 'DEFAULT_QUEUE'),
    'task_track_started': conf.getboolean('celery', 'task_track_started'),
    'broker_url': broker_url,
    'broker_transport_options': broker_transport_options,
    'result_backend': conf.get('celery', 'RESULT_BACKEND'),
    'worker_concurrency':
Exemple #29
0
    configure_vars()
    prepare_syspath()
    import_local_settings()
    global LOGGING_CLASS_PATH
    LOGGING_CLASS_PATH = configure_logging()
    configure_adapters()
    # The webservers import this file from models.py with the default settings.
    configure_orm()
    configure_action_logging()

    # Ensure we close DB connections at scheduler and gunicorn worker terminations
    atexit.register(dispose_orm)


# Const stuff

KILOBYTE = 1024
MEGABYTE = KILOBYTE * KILOBYTE
WEB_COLORS = {'LIGHTBLUE': '#4d9de0', 'LIGHTORANGE': '#FF9933'}

# If store_serialized_dags is True, scheduler writes serialized DAGs to DB, and webserver
# reads DAGs from DB instead of importing from files.
STORE_SERIALIZED_DAGS = conf.getboolean('core',
                                        'store_serialized_dags',
                                        fallback=False)

# A serialized DAG cannot be updated faster than this minimum interval, to limit the
# database write rate.
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
    'core', 'min_serialized_dag_update_interval', fallback=30)
Exemple #30
0
                action="store_true")
ARG_ACCESS_LOGFILE = Arg(
    ("-A", "--access-logfile"),
    default=conf.get('webserver', 'ACCESS_LOGFILE'),
    help="The logfile to store the webserver access log. Use '-' to print to "
    "stderr")
ARG_ERROR_LOGFILE = Arg(
    ("-E", "--error-logfile"),
    default=conf.get('webserver', 'ERROR_LOGFILE'),
    help="The logfile to store the webserver error log. Use '-' to print to "
    "stderr")

# scheduler
ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag to run")
ARG_NUM_RUNS = Arg(("-n", "--num-runs"),
                   default=conf.getint('scheduler', 'num_runs'),
                   type=int,
                   help="Set the number of runs to execute before exiting")

# worker
ARG_DO_PICKLE = Arg(
    ("-p", "--do-pickle"),
    default=False,
    help=("Attempt to pickle the DAG object to send over "
          "to the workers, instead of letting workers run their version "
          "of the code"),
    action="store_true")
ARG_QUEUES = Arg(("-q", "--queues"),
                 help="Comma delimited list of queues to serve",
                 default=conf.get('celery', 'DEFAULT_QUEUE'))
ARG_CONCURRENCY = Arg(("-c", "--concurrency"),
Exemple #31
0
class DagBag(BaseDagBag, LoggingMixin):
    """
    A dagbag is a collection of dags, parsed out of a folder tree and has high
    level configuration settings, like what database to use as a backend and
    what executor to use to fire off tasks. This makes it easier to run
    distinct environments for say production and development, tests, or for
    different teams or security profiles. What would have been system level
    settings are now dagbag level so that one system can run multiple,
    independent settings sets.

    :param dag_folder: the folder to scan to find DAGs
    :type dag_folder: unicode
    :param executor: the executor to use when executing task instances
        in this DagBag
    :param include_examples: whether to include the examples that ship
        with airflow or not
    :type include_examples: bool
    :param has_logged: an instance boolean that gets flipped from False to True after a
        file has been skipped. This prevents overloading the user with logging messages
        about skipped files, so each skipped file is only logged once per DagBag.
    :param store_serialized_dags: Read DAGs from DB if store_serialized_dags is ``True``.
        If ``False`` DAGs are read from python files.
    :type store_serialized_dags: bool
    """

    # static class variables used to detect DAG cycles
    CYCLE_NEW = 0
    CYCLE_IN_PROGRESS = 1
    CYCLE_DONE = 2
    DAGBAG_IMPORT_TIMEOUT = conf.getint('core', 'DAGBAG_IMPORT_TIMEOUT')
    SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint(
        'scheduler', 'scheduler_zombie_task_threshold')

    def __init__(
        self,
        dag_folder=None,
        include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
        safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
        store_serialized_dags=False,
    ):

        dag_folder = dag_folder or settings.DAGS_FOLDER
        self.dag_folder = dag_folder
        self.dags = {}
        # the file's last modified timestamp when we last read it
        self.file_last_changed = {}
        self.import_errors = {}
        self.has_logged = False
        self.store_serialized_dags = store_serialized_dags

        self.collect_dags(dag_folder=dag_folder,
                          include_examples=include_examples,
                          safe_mode=safe_mode)

    def size(self):
        """
        :return: the number of DAGs contained in this dagbag
        """
        return len(self.dags)

    @property
    def dag_ids(self) -> List[str]:
        return list(self.dags.keys())

    def get_dag(self, dag_id):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        # Only read DAGs from DB if this dagbag is store_serialized_dags.
        if self.store_serialized_dags:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel
            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                row = SerializedDagModel.get(dag_id)
                if not row:
                    return None

                dag = row.dag
                for subdag in dag.subdags:
                    self.dags[subdag.dag_id] = subdag
                self.dags[dag.dag_id] = dag

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # Needs to load from file for a store_serialized_dags dagbag.
        enforce_from_file = False
        if self.store_serialized_dags and dag is not None:
            from airflow.serialization.serialized_objects import SerializedDAG
            enforce_from_file = isinstance(dag, SerializedDAG)

        # If the dag corresponding to root_dag_id is absent or expired
        orm_dag = DagModel.get_current(root_dag_id)
        if (orm_dag and
            (root_dag_id not in self.dags or
             (orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired))
            ) or enforce_from_file:
            # Reprocess source file
            found_dags = self.process_file(filepath=correct_maybe_zipped(
                orm_dag.fileloc),
                                           only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [
                    found_dag.dag_id for found_dag in found_dags
            ]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)

    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for DAG objects within it.
        """
        from airflow.models.dag import DAG  # Avoid circular import

        integrate_dag_plugins()
        found_dags = []

        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return found_dags

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(
                os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.file_last_changed \
                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
                return found_dags

        except Exception as e:
            self.log.exception(e)
            return found_dags

        mods = []
        is_zipfile = zipfile.is_zipfile(filepath)
        if not is_zipfile:
            if safe_mode:
                with open(filepath, 'rb') as file:
                    content = file.read()
                    if not all([s in content for s in (b'DAG', b'airflow')]):
                        self.file_last_changed[
                            filepath] = file_last_changed_on_disk
                        # Don't want to spam user with skip messages
                        if not self.has_logged:
                            self.has_logged = True
                            self.log.info(
                                "File %s assumed to contain no DAGs. Skipping.",
                                filepath)
                        return found_dags

            self.log.debug("Importing %s", filepath)
            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
            mod_name = ('unusual_prefix_' +
                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                        '_' + org_mod_name)

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            with timeout(self.DAGBAG_IMPORT_TIMEOUT):
                try:
                    loader = importlib.machinery.SourceFileLoader(
                        mod_name, filepath)
                    spec = importlib.util.spec_from_loader(mod_name, loader)
                    m = importlib.util.module_from_spec(spec)
                    sys.modules[spec.name] = m
                    loader.exec_module(m)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.import_errors[filepath] = str(e)
                    self.file_last_changed[
                        filepath] = file_last_changed_on_disk

        else:
            zip_file = zipfile.ZipFile(filepath)
            for mod in zip_file.infolist():
                head, _ = os.path.split(mod.filename)
                mod_name, ext = os.path.splitext(mod.filename)
                if not head and (ext == '.py' or ext == '.pyc'):
                    if mod_name == '__init__':
                        self.log.warning("Found __init__.%s at root of %s",
                                         ext, filepath)
                    if safe_mode:
                        with zip_file.open(mod.filename) as zf:
                            self.log.debug("Reading %s from %s", mod.filename,
                                           filepath)
                            content = zf.read()
                            if not all(
                                [s in content for s in (b'DAG', b'airflow')]):
                                self.file_last_changed[filepath] = (
                                    file_last_changed_on_disk)
                                # todo: create ignore list
                                # Don't want to spam user with skip messages
                                if not self.has_logged:
                                    self.has_logged = True
                                    self.log.info(
                                        "File %s assumed to contain no DAGs. Skipping.",
                                        filepath)

                    if mod_name in sys.modules:
                        del sys.modules[mod_name]

                    try:
                        sys.path.insert(0, filepath)
                        m = importlib.import_module(mod_name)
                        mods.append(m)
                    except Exception as e:
                        self.log.exception("Failed to import: %s", filepath)
                        self.import_errors[filepath] = str(e)
                        self.file_last_changed[
                            filepath] = file_last_changed_on_disk

        for m in mods:
            for dag in list(m.__dict__.values()):
                if isinstance(dag, DAG):
                    if not dag.full_filepath:
                        dag.full_filepath = filepath
                        if dag.fileloc != filepath and not is_zipfile:
                            dag.fileloc = filepath
                    try:
                        dag.is_subdag = False
                        self.bag_dag(dag, parent_dag=dag, root_dag=dag)
                        if isinstance(dag._schedule_interval, str):
                            croniter(dag._schedule_interval)
                        found_dags.append(dag)
                        found_dags += dag.subdags
                    except (CroniterBadCronError, CroniterBadDateError,
                            CroniterNotAlphaError) as cron_e:
                        self.log.exception("Failed to bag_dag: %s",
                                           dag.full_filepath)
                        self.import_errors[dag.full_filepath] = \
                            "Invalid Cron expression: " + str(cron_e)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk
                    except AirflowDagCycleException as cycle_exception:
                        self.log.exception("Failed to bag_dag: %s",
                                           dag.full_filepath)
                        self.import_errors[dag.full_filepath] = str(
                            cycle_exception)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags

    def bag_dag(self, dag, parent_dag, root_dag):
        """
        Adds the DAG into the bag, recurses into sub dags.
        Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags
        """

        dag.test_cycle()  # throws if a task cycle is found

        dag.resolve_template_files()
        dag.last_loaded = timezone.utcnow()

        for task in dag.tasks:
            settings.policy(task)

        subdags = dag.subdags

        try:
            for subdag in subdags:
                subdag.full_filepath = dag.full_filepath
                subdag.parent_dag = dag
                subdag.is_subdag = True
                self.bag_dag(subdag, parent_dag=dag, root_dag=root_dag)

            self.dags[dag.dag_id] = dag
            self.log.debug('Loaded DAG %s', dag)
        except AirflowDagCycleException as cycle_exception:
            # There was an error in bagging the dag. Remove it from the list of dags
            self.log.exception('Exception bagging dag: %s', dag.dag_id)
            # Only necessary at the root level since DAG.subdags automatically
            # performs DFS to search through all subdags
            if dag == root_dag:
                for subdag in subdags:
                    if subdag.dag_id in self.dags:
                        del self.dags[subdag.dag_id]
            raise cycle_exception

    def collect_dags(self,
                     dag_folder=None,
                     only_if_updated=True,
                     include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
                     safe_mode=conf.getboolean('core',
                                               'DAG_DISCOVERY_SAFE_MODE')):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a ``.airflowignore`` file is found while processing
        the directory, it will behave much like a ``.gitignore``,
        ignoring files that match any of the regex patterns specified
        in the file.

        **Note**: The patterns in .airflowignore are treated as
        un-anchored regexes, not shell-like glob patterns.
        """
        if self.store_serialized_dags:
            return

        self.log.info("Filling up the DagBag from %s", dag_folder)
        start_dttm = timezone.utcnow()
        dag_folder = dag_folder or self.dag_folder
        # Used to store stats around DagBag processing
        stats = []

        from airflow.utils.file import correct_maybe_zipped, list_py_file_paths
        dag_folder = correct_maybe_zipped(dag_folder)
        for filepath in list_py_file_paths(dag_folder,
                                           safe_mode=safe_mode,
                                           include_examples=include_examples):
            try:
                ts = timezone.utcnow()
                found_dags = self.process_file(filepath,
                                               only_if_updated=only_if_updated,
                                               safe_mode=safe_mode)
                dag_ids = [dag.dag_id for dag in found_dags]
                dag_id_names = str(dag_ids)

                td = timezone.utcnow() - ts
                stats.append(
                    FileLoadStat(
                        filepath.replace(settings.DAGS_FOLDER, ''),
                        td,
                        len(found_dags),
                        sum([len(dag.tasks) for dag in found_dags]),
                        dag_id_names,
                    ))
            except Exception as e:
                self.log.exception(e)
        Stats.gauge('collect_dags',
                    (timezone.utcnow() - start_dttm).total_seconds(), 1)
        Stats.gauge('dagbag_size', len(self.dags), 1)
        Stats.gauge('dagbag_import_errors', len(self.import_errors), 1)
        self.dagbag_stats = sorted(stats,
                                   key=lambda x: x.duration,
                                   reverse=True)
        for file_stat in self.dagbag_stats:
            # file_stat.file similar format: /subdir/dag_name.py
            # TODO: Remove for Airflow 2.0
            filename = file_stat.file.split('/')[-1].replace('.py', '')
            Stats.timing('dag.loading-duration.{}'.format(filename),
                         file_stat.duration)

    def collect_dags_from_db(self):
        """Collects DAGs from database."""
        from airflow.models.serialized_dag import SerializedDagModel
        start_dttm = timezone.utcnow()
        self.log.info("Filling up the DagBag from database")

        # The dagbag contains all rows in serialized_dag table. Deleted DAGs are deleted
        # from the table by the scheduler job.
        self.dags = SerializedDagModel.read_all_dags()

        # Adds subdags.
        # DAG post-processing steps such as self.bag_dag and croniter are not needed as
        # they are done by scheduler before serialization.
        subdags = {}
        for dag in self.dags.values():
            for subdag in dag.subdags:
                subdags[subdag.dag_id] = subdag
        self.dags.update(subdags)

        Stats.timing('collect_db_dags', timezone.utcnow() - start_dttm)

    def dagbag_report(self):
        """Prints a report around DagBag loading stats"""
        report = textwrap.dedent("""\n
        -------------------------------------------------------------------
        DagBag loading stats for {dag_folder}
        -------------------------------------------------------------------
        Number of DAGs: {dag_num}
        Total task number: {task_num}
        DagBag parsing time: {duration}
        {table}
        """)
        stats = self.dagbag_stats
        return report.format(
            dag_folder=self.dag_folder,
            duration=sum([o.duration for o in stats],
                         timedelta()).total_seconds(),
            dag_num=sum([o.dag_num for o in stats]),
            task_num=sum([o.task_num for o in stats]),
            table=tabulate(stats, headers="keys"),
        )

    def sync_to_db(self):
        """
        Save attributes of the DAGs in this DagBag to the DB.
        """
        from airflow.models.dag import DAG
        DAG.bulk_sync_to_db(self.dags.values())
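
# Hedged usage sketch for the DagBag above (the folder path and dag id are made up):
# fill the bag from a folder, inspect import errors, and fetch a DAG by id.
bag = DagBag(dag_folder='/path/to/dags', include_examples=False)
print(bag.size(), 'DAGs loaded;', len(bag.import_errors), 'import errors')
dag = bag.get_dag('example_dag_id')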
Exemple #32
0
def prepare_engine_args(disable_connection_pool=False):
    """Prepare SQLAlchemy engine args"""
    engine_args = {}
    pool_connections = conf.getboolean('core', 'SQL_ALCHEMY_POOL_ENABLED')
    if disable_connection_pool or not pool_connections:
        engine_args['poolclass'] = NullPool
        log.debug("settings.prepare_engine_args(): Using NullPool")
    elif not SQL_ALCHEMY_CONN.startswith('sqlite'):
        # Pool size engine args not supported by sqlite.
        # If no config value is defined for the pool size, select a reasonable value.
        # 0 means no limit, which could lead to exceeding the Database connection limit.
        pool_size = conf.getint('core', 'SQL_ALCHEMY_POOL_SIZE', fallback=5)

        # The maximum overflow size of the pool.
        # When the number of checked-out connections reaches the size set in pool_size,
        # additional connections will be returned up to this limit.
        # When those additional connections are returned to the pool, they are disconnected and discarded.
        # It follows then that the total number of simultaneous connections
        # the pool will allow is pool_size + max_overflow,
        # and the total number of “sleeping” connections the pool will allow is pool_size.
        # max_overflow can be set to -1 to indicate no overflow limit;
        # no limit will be placed on the total number
        # of concurrent connections. Defaults to 10.
        max_overflow = conf.getint('core',
                                   'SQL_ALCHEMY_MAX_OVERFLOW',
                                   fallback=10)

        # The DB server already has a value for wait_timeout (number of seconds after
        # which an idle sleeping connection should be killed). Since other DBs may
        # co-exist on the same server, SQLAlchemy should set its
        # pool_recycle to an equal or smaller value.
        pool_recycle = conf.getint('core',
                                   'SQL_ALCHEMY_POOL_RECYCLE',
                                   fallback=1800)

        # Check connection at the start of each connection pool checkout.
        # Typically, this is a simple statement like “SELECT 1”, but may also make use
        # of some DBAPI-specific method to test the connection for liveness.
        # More information here:
        # https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
        pool_pre_ping = conf.getboolean('core',
                                        'SQL_ALCHEMY_POOL_PRE_PING',
                                        fallback=True)

        log.debug(
            "settings.prepare_engine_args(): Using pool settings. pool_size=%d, max_overflow=%d, "
            "pool_recycle=%d, pid=%d",
            pool_size,
            max_overflow,
            pool_recycle,
            os.getpid(),
        )
        engine_args['pool_size'] = pool_size
        engine_args['pool_recycle'] = pool_recycle
        engine_args['pool_pre_ping'] = pool_pre_ping
        engine_args['max_overflow'] = max_overflow

    # The default isolation level for MySQL (REPEATABLE READ) can introduce inconsistencies when
    # running multiple schedulers, as repeated queries on the same session may read from stale snapshots.
    # 'READ COMMITTED' is the default value for PostgreSQL.
    # More information here:
    # https://dev.mysql.com/doc/refman/8.0/en/innodb-transaction-isolation-levels.html
    if SQL_ALCHEMY_CONN.startswith('mysql'):
        engine_args['isolation_level'] = 'READ COMMITTED'

    return engine_args
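
# Standalone sketch of how these arguments reach SQLAlchemy (the connection URL is made up
# and a matching DB driver would need to be installed): create_engine accepts the pool
# sizing options and the isolation level computed above as keyword arguments.
from sqlalchemy import create_engine

engine = create_engine(
    'mysql://user:password@localhost/airflow',  # hypothetical connection string
    pool_size=5,
    max_overflow=10,
    pool_recycle=1800,
    pool_pre_ping=True,
    isolation_level='READ COMMITTED',
)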
Exemple #33
0
# under the License.

from builtins import range
from collections import OrderedDict
import copy
# To avoid circular imports
import airflow.utils.dag_processing
from airflow.configuration import conf
from notification_service.client import NotificationClient
from airflow.settings import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
from airflow.models.event import TaskStatusEvent, TaskInstanceHelper


PARALLELISM = conf.getint('core', 'PARALLELISM')


class BaseExecutor(LoggingMixin):

    def __init__(self, parallelism=PARALLELISM):
        """
        Class to derive in order to interface with executor-type systems
        like Celery, Mesos, Yarn and the likes.

        :param parallelism: how many jobs should run at one time. Set to
            ``0`` for infinity
        :type parallelism: int
        """
        self.parallelism = parallelism
        self.queued_tasks = OrderedDict()
Exemple #34
0
def create_app():
    flask_app = Flask(__name__, static_folder=None)
    expiration_time_in_seconds = conf.getint('webserver',
                                             'log_request_clock_grace',
                                             fallback=30)
    log_directory = os.path.expanduser(conf.get('logging', 'BASE_LOG_FOLDER'))

    signer = JWTSigner(
        secret_key=conf.get('webserver', 'secret_key'),
        expiration_time_in_seconds=expiration_time_in_seconds,
        audience="task-instance-logs",
    )

    # Prevent direct access to the logs port
    @flask_app.before_request
    def validate_pre_signed_url():
        try:
            auth = request.headers.get('Authorization')
            if auth is None:
                logger.warning("The Authorization header is missing: %s.",
                               request.headers)
                abort(403)
            payload = signer.verify_token(auth)
            token_filename = payload.get("filename")
            request_filename = request.view_args['filename']
            if token_filename is None:
                logger.warning(
                    "The payload does not contain 'filename' key: %s.",
                    payload)
                abort(403)
            if token_filename != request_filename:
                logger.warning(
                    "The payload log_relative_path key is different than the one in token:"
                    "Request path: %s. Token path: %s.",
                    request_filename,
                    token_filename,
                )
                abort(403)
        except InvalidAudienceError:
            logger.warning("Invalid audience for the request", exc_info=True)
            abort(403)
        except InvalidSignatureError:
            logger.warning("The signature of the request was wrong",
                           exc_info=True)
            abort(403)
        except ImmatureSignatureError:
            logger.warning(
                "The signature of the request was sent from the future",
                exc_info=True)
            abort(403)
        except ExpiredSignatureError:
            logger.warning(
                "The signature of the request has expired. Make sure that all components "
                "in your system have synchronized clocks. "
                "See more at %s",
                get_docs_url("configurations-ref.html#secret-key"),
                exc_info=True,
            )
            abort(403)
        except InvalidIssuedAtError:
            logger.warning(
                "The request was issues in the future. Make sure that all components "
                "in your system have synchronized clocks. "
                "See more at %s",
                get_docs_url("configurations-ref.html#secret-key"),
                exc_info=True,
            )
            abort(403)
        except Exception:
            logger.warning("Unknown error", exc_info=True)
            abort(403)

    @flask_app.route('/log/<path:filename>')
    def serve_logs_view(filename):
        return send_from_directory(log_directory,
                                   filename,
                                   mimetype="application/json",
                                   as_attachment=False)

    return flask_app
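
# A rough standalone illustration, using PyJWT directly rather than Airflow's JWTSigner,
# of the claims the handler above verifies: audience, expiry, issued-at, plus a filename
# bound into the payload. The secret and filename values are made up.
import datetime

import jwt

secret = 'shared-secret'
now = datetime.datetime.now(tz=datetime.timezone.utc)
token = jwt.encode(
    {
        'aud': 'task-instance-logs',
        'iat': now,
        'exp': now + datetime.timedelta(seconds=30),
        'filename': 'dag_id/task_id/2021-01-01/1.log',
    },
    secret,
    algorithm='HS512',
)
claims = jwt.decode(token, secret, algorithms=['HS512'], audience='task-instance-logs')
assert claims['filename'] == 'dag_id/task_id/2021-01-01/1.log'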
Exemple #35
0
from sqlalchemy import (
        Column, Integer, String, DateTime, ForeignKey)
from sqlalchemy import func
from sqlalchemy.orm.session import make_transient

from airflow.executors import DEFAULT_EXECUTOR
from airflow.configuration import conf
from airflow import models
from airflow import settings
from airflow import utils
import socket
from airflow.utils import State


Base = models.Base
ID_LEN = conf.getint('misc', 'ID_LEN')

class BaseJob(Base):
    """
    Abstract class to be derived for jobs. Jobs are processing items with state
    and duration that aren't task instances. For instance a BackfillJob is
    a collection of task instance runs, but should have its own state, start
    and end time.
    """

    __tablename__ = "job"

    id = Column(Integer, primary_key=True)
    dag_id = Column(String(ID_LEN),)
    state = Column(String(20))
    job_type = Column(String(30))
Exemple #36
0
class DagRun(Base, LoggingMixin):
    """
    DagRun describes an instance of a Dag. It can be created
    by the scheduler (for regular runs) or by an external trigger
    """

    __tablename__ = "dag_run"

    id = Column(Integer, primary_key=True)
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    queued_at = Column(UtcDateTime)
    execution_date = Column(UtcDateTime,
                            default=timezone.utcnow,
                            nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    _state = Column('state', String(50), default=State.QUEUED)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
    creating_job_id = Column(Integer)
    external_trigger = Column(Boolean, default=True)
    run_type = Column(String(50), nullable=False)
    conf = Column(PickleType)
    # These two must be either both NULL or both datetime.
    data_interval_start = Column(UtcDateTime)
    data_interval_end = Column(UtcDateTime)
    # When a scheduler last attempted to schedule TIs for this DagRun
    last_scheduling_decision = Column(UtcDateTime)
    dag_hash = Column(String(32))
    # Foreign key to LogTemplate. DagRun rows created prior to this column's
    # existence have this set to NULL. Later rows automatically populate this on
    # insert to point to the latest LogTemplate entry.
    log_template_id = Column(
        Integer,
        ForeignKey("log_template.id",
                   name="task_instance_log_template_id_fkey",
                   ondelete="NO ACTION"),
        default=select([func.max(LogTemplate.__table__.c.id)]),
    )

    # Remove this `if` after upgrading Sphinx-AutoAPI
    if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
        dag: "Optional[DAG]"
    else:
        dag: "Optional[DAG]" = None

    __table_args__ = (
        Index('dag_id_state', dag_id, _state),
        UniqueConstraint('dag_id',
                         'execution_date',
                         name='dag_run_dag_id_execution_date_key'),
        UniqueConstraint('dag_id', 'run_id', name='dag_run_dag_id_run_id_key'),
        Index('idx_last_scheduling_decision', last_scheduling_decision),
        Index('idx_dag_run_dag_id', dag_id),
        Index(
            'idx_dag_run_running_dags',
            'state',
            'dag_id',
            postgresql_where=text("state='running'"),
            mssql_where=text("state='running'"),
            sqlite_where=text("state='running'"),
        ),
        # since mysql lacks filtered/partial indices, this creates a
        # duplicate index on mysql. Not the end of the world
        Index(
            'idx_dag_run_queued_dags',
            'state',
            'dag_id',
            postgresql_where=text("state='queued'"),
            mssql_where=text("state='queued'"),
            sqlite_where=text("state='queued'"),
        ),
    )

    task_instances = relationship(
        TI,
        back_populates="dag_run",
        cascade='save-update, merge, delete, delete-orphan')

    DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
        'scheduler',
        'max_dagruns_per_loop_to_schedule',
        fallback=20,
    )

    def __init__(
        self,
        dag_id: Optional[str] = None,
        run_id: Optional[str] = None,
        queued_at: Union[datetime, None, ArgNotSet] = NOTSET,
        execution_date: Optional[datetime] = None,
        start_date: Optional[datetime] = None,
        external_trigger: Optional[bool] = None,
        conf: Optional[Any] = None,
        state: Optional[DagRunState] = None,
        run_type: Optional[str] = None,
        dag_hash: Optional[str] = None,
        creating_job_id: Optional[int] = None,
        data_interval: Optional[Tuple[datetime, datetime]] = None,
    ):
        if data_interval is None:
            # Legacy: only happens for runs created prior to Airflow 2.2.
            self.data_interval_start = self.data_interval_end = None
        else:
            self.data_interval_start, self.data_interval_end = data_interval

        self.dag_id = dag_id
        self.run_id = run_id
        self.execution_date = execution_date
        self.start_date = start_date
        self.external_trigger = external_trigger
        self.conf = conf or {}
        if state is not None:
            self.state = state
        if queued_at is NOTSET:
            self.queued_at = timezone.utcnow(
            ) if state == State.QUEUED else None
        else:
            self.queued_at = queued_at
        self.run_type = run_type
        self.dag_hash = dag_hash
        self.creating_job_id = creating_job_id
        super().__init__()

    def __repr__(self):
        return (
            '<DagRun {dag_id} @ {execution_date}: {run_id}, externally triggered: {external_trigger}>'
        ).format(
            dag_id=self.dag_id,
            execution_date=self.execution_date,
            run_id=self.run_id,
            external_trigger=self.external_trigger,
        )

    @property
    def logical_date(self) -> datetime:
        return self.execution_date

    def get_state(self):
        return self._state

    def set_state(self, state: DagRunState):
        if state not in State.dag_states:
            raise ValueError(f"invalid DagRun state: {state}")
        if self._state != state:
            self._state = state
            self.end_date = timezone.utcnow(
            ) if self._state in State.finished else None
            if state == State.QUEUED:
                self.queued_at = timezone.utcnow()

    @declared_attr
    def state(self):
        return synonym('_state',
                       descriptor=property(self.get_state, self.set_state))

    @provide_session
    def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
        """
        Reloads the current dagrun from the database

        :param session: database session
        """
        dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id,
                                          DagRun.run_id == self.run_id).one()
        self.id = dr.id
        self.state = dr.state

    @classmethod
    @provide_session
    def active_runs_of_dags(cls,
                            dag_ids=None,
                            only_running=False,
                            session=None) -> Dict[str, int]:
        """Get the number of active dag runs for each dag."""
        query = session.query(cls.dag_id, func.count('*'))
        if dag_ids is not None:
            # 'set' called to avoid duplicate dag_ids, but converted back to 'list'
            # because SQLAlchemy doesn't accept a set here.
            query = query.filter(cls.dag_id.in_(list(set(dag_ids))))
        if only_running:
            query = query.filter(cls.state == State.RUNNING)
        else:
            query = query.filter(cls.state.in_([State.RUNNING, State.QUEUED]))
        query = query.group_by(cls.dag_id)
        return {dag_id: count for dag_id, count in query.all()}

    @classmethod
    def next_dagruns_to_examine(
        cls,
        state: DagRunState,
        session: Session,
        max_number: Optional[int] = None,
    ):
        """
        Return the next DagRuns that the scheduler should attempt to schedule.

        This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
        query; you should ensure that any scheduling decisions are made in a single transaction -- as soon
        as the transaction is committed, the rows are unlocked.

        :rtype: list[airflow.models.DagRun]
        """
        from airflow.models.dag import DagModel

        if max_number is None:
            max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE

        # TODO: Bake this query, it is run _A lot_
        query = (session.query(cls).filter(
            cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB).join(
                DagModel, DagModel.dag_id == cls.dag_id).filter(
                    DagModel.is_paused == false(),
                    DagModel.is_active == true()))
        if state == State.QUEUED:
            # For dag runs in the queued state, we check if they have reached the max_active_runs limit
            # and if so we drop them
            running_drs = (session.query(
                DagRun.dag_id,
                func.count(DagRun.state).label('num_running')).filter(
                    DagRun.state == DagRunState.RUNNING).group_by(
                        DagRun.dag_id).subquery())
            query = query.outerjoin(
                running_drs, running_drs.c.dag_id == DagRun.dag_id).filter(
                    func.coalesce(running_drs.c.num_running, 0) <
                    DagModel.max_active_runs)
        query = query.order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )

        if not settings.ALLOW_FUTURE_EXEC_DATES:
            query = query.filter(DagRun.execution_date <= func.now())

        return with_row_locks(query.limit(max_number),
                              of=cls,
                              session=session,
                              **skip_locked(session=session))

    @classmethod
    @provide_session
    def find(
        cls,
        dag_id: Optional[Union[str, List[str]]] = None,
        run_id: Optional[Iterable[str]] = None,
        execution_date: Optional[Union[datetime, Iterable[datetime]]] = None,
        state: Optional[DagRunState] = None,
        external_trigger: Optional[bool] = None,
        no_backfills: bool = False,
        run_type: Optional[DagRunType] = None,
        session: Session = NEW_SESSION,
        execution_start_date: Optional[datetime] = None,
        execution_end_date: Optional[datetime] = None,
    ) -> List["DagRun"]:
        """
        Returns the list of dag runs matching the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :param run_id: defines the run id for this dag run
        :param run_type: type of DagRun
        :param execution_date: the execution date
        :param state: the state of the dag run
        :param external_trigger: whether this dag run is externally triggered
        :param no_backfills: exclude backfill runs if True, return all runs if False.
            Defaults to False
        :param session: database session
        :param execution_start_date: return dag runs executed on or after this date
        :param execution_end_date: return dag runs executed on or before this date
        """
        qry = session.query(cls)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.filter(cls.dag_id.in_(dag_ids))

        if is_container(run_id):
            qry = qry.filter(cls.run_id.in_(run_id))
        elif run_id is not None:
            qry = qry.filter(cls.run_id == run_id)
        if is_container(execution_date):
            qry = qry.filter(cls.execution_date.in_(execution_date))
        elif execution_date is not None:
            qry = qry.filter(cls.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.filter(
                cls.execution_date.between(execution_start_date,
                                           execution_end_date))
        elif execution_start_date:
            qry = qry.filter(cls.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.filter(cls.execution_date <= execution_end_date)
        if state:
            qry = qry.filter(cls.state == state)
        if external_trigger is not None:
            qry = qry.filter(cls.external_trigger == external_trigger)
        if run_type:
            qry = qry.filter(cls.run_type == run_type)
        if no_backfills:
            qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB)

        return qry.order_by(cls.execution_date).all()
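
    # Example usage (dag ids and dates are illustrative only):
    #
    #     running = DagRun.find(dag_id="example_dag", state=DagRunState.RUNNING)
    #     window = DagRun.find(
    #         dag_id=["dag_a", "dag_b"],
    #         execution_start_date=datetime(2021, 1, 1),
    #         execution_end_date=datetime(2021, 2, 1),
    #     )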

    @classmethod
    @provide_session
    def find_duplicate(
        cls,
        dag_id: str,
        run_id: str,
        execution_date: datetime,
        session: Session = NEW_SESSION,
    ) -> Optional['DagRun']:
        """
        Return an existing run for the DAG with a specific run_id or execution_date.

        *None* is returned if no such DAG run is found.

        :param dag_id: the dag_id to find duplicates for
        :param run_id: defines the run id for this dag run
        :param execution_date: the execution date
        :param session: database session
        """
        return (session.query(cls).filter(
            cls.dag_id == dag_id,
            or_(cls.run_id == run_id, cls.execution_date == execution_date),
        ).one_or_none())

    @staticmethod
    def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
        """Generate Run ID based on Run Type and Execution Date"""
        return f"{run_type}__{execution_date.isoformat()}"

    @provide_session
    def get_task_instances(
        self,
        state: Optional[Iterable[Optional[TaskInstanceState]]] = None,
        session: Session = NEW_SESSION,
    ) -> Iterable[TI]:
        """Returns the task instances for this dag run"""
        tis = (session.query(TI).options(joinedload(TI.dag_run)).filter(
            TI.dag_id == self.dag_id,
            TI.run_id == self.run_id,
        ))

        if state:
            if isinstance(state, str):
                tis = tis.filter(TI.state == state)
            else:
                # this is required to deal with NULL values
                if State.NONE in state:
                    if all(x is None for x in state):
                        tis = tis.filter(TI.state.is_(None))
                    else:
                        not_none_state = [s for s in state if s]
                        tis = tis.filter(
                            or_(TI.state.in_(not_none_state),
                                TI.state.is_(None)))
                else:
                    tis = tis.filter(TI.state.in_(state))

        if self.dag and self.dag.partial:
            tis = tis.filter(TI.task_id.in_(self.dag.task_ids))
        return tis.all()

    @provide_session
    def get_task_instance(
        self,
        task_id: str,
        session: Session = NEW_SESSION,
        *,
        map_index: int = -1,
    ) -> Optional[TI]:
        """
        Returns the task instance specified by task_id for this dag run

        :param task_id: the task id
        :param session: Sqlalchemy ORM Session
        """
        return (session.query(TI).filter_by(dag_id=self.dag_id,
                                            run_id=self.run_id,
                                            task_id=task_id,
                                            map_index=map_index).one_or_none())

    def get_dag(self) -> "DAG":
        """
        Returns the Dag associated with this DagRun.

        :return: DAG
        """
        if not self.dag:
            raise AirflowException(
                f"The DAG (.dag) for {self} needs to be set")

        return self.dag

    @provide_session
    def get_previous_dagrun(
            self,
            state: Optional[DagRunState] = None,
            session: Session = NEW_SESSION) -> Optional['DagRun']:
        """The previous DagRun, if there is one"""
        filters = [
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date < self.execution_date,
        ]
        if state is not None:
            filters.append(DagRun.state == state)
        return session.query(DagRun).filter(*filters).order_by(
            DagRun.execution_date.desc()).first()

    @provide_session
    def get_previous_scheduled_dagrun(self,
                                      session: Session = NEW_SESSION
                                      ) -> Optional['DagRun']:
        """The previous, SCHEDULED DagRun, if there is one"""
        return (session.query(DagRun).filter(
            DagRun.dag_id == self.dag_id,
            DagRun.execution_date < self.execution_date,
            DagRun.run_type != DagRunType.MANUAL,
        ).order_by(DagRun.execution_date.desc()).first())

    @provide_session
    def update_state(
        self,
        session: Session = NEW_SESSION,
        execute_callbacks: bool = True
    ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
            directly (default: true) or recorded as a pending request in the ``callback`` property
        :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: Optional[callback_requests.DagCallbackRequest] = None

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tasks = info.finished_tasks
            unfinished_tasks = info.unfinished_tasks

            none_depends_on_past = all(
                not t.task.depends_on_past
                for t in unfinished_tasks  # type: ignore[has-type]
            )
            none_task_concurrency = all(
                t.task.max_active_tis_per_dag is None
                for t in unfinished_tasks  # type: ignore[has-type]
            )
            none_deferred = all(t.state != State.DEFERRED
                                for t in unfinished_tasks)

            if unfinished_tasks and none_depends_on_past and none_task_concurrency and none_deferred:
                # small speed up
                are_runnable_tasks = (schedulable_tis
                                      or self._are_premature_tis(
                                          unfinished_tasks, finished_tasks,
                                          session) or changed_tis)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all tasks finished and at least one leaf failed, the run failed
        if not unfinished_tasks and any(leaf_ti.state in State.failed_states
                                        for leaf_ti in leaf_tis):
            self.log.error('Marking run %s failed', self)
            self.set_state(DagRunState.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='task_failure',
                                    session=session)
            elif dag.has_on_failure_callback:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    msg='task_failure',
                )

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tasks and all(leaf_ti.state in State.success_states
                                          for leaf_ti in leaf_tis):
            self.log.info('Marking run %s successful', self)
            self.set_state(DagRunState.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=True,
                                    reason='success',
                                    session=session)
            elif dag.has_on_success_callback:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    msg='success',
                )

        # if *all tasks* are deadlocked, the run failed
        elif (unfinished_tasks and none_depends_on_past
              and none_task_concurrency and none_deferred
              and not are_runnable_tasks):
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(DagRunState.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='all_tasks_deadlocked',
                                    session=session)
            elif dag.has_on_failure_callback:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    msg='all_tasks_deadlocked',
                )

        # finally, if the tasks aren't done, the dag is still running
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = ("DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                   "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                   "state=%s, external_trigger=%s, run_type=%s, "
                   "data_interval_start=%s, data_interval_end=%s, dag_hash=%s")
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (self.end_date - self.start_date).total_seconds()
                if self.start_date and self.end_date else None,
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )

        self._emit_true_scheduling_delay_stats_for_finished_state(
            finished_tasks)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)

        return schedulable_tis, callback
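
    # A hedged scheduler-side sketch of consuming the return value:
    #
    #     schedulable_tis, callback = dag_run.update_state(session=session, execute_callbacks=False)
    #     dag_run.schedule_tis(schedulable_tis, session=session)
    #     # if callback is not None, it is handed off to the DAG processor for execution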

    @provide_session
    def task_instance_scheduling_decisions(self,
                                           session: Session = NEW_SESSION
                                           ) -> TISchedulingDecision:

        schedulable_tis: List[TI] = []
        changed_tis = False

        tis = list(
            self.get_task_instances(session=session, state=State.task_states))
        self.log.debug("number of tis tasks for %s: %s task(s)", self,
                       len(tis))
        for ti in tis:
            try:
                ti.task = self.get_dag().get_task(ti.task_id)
            except TaskNotFound:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                    ti, ti.dag_id)
                ti.state = State.REMOVED
                session.flush()

        unfinished_tasks = [t for t in tis if t.state in State.unfinished]
        finished_tasks = [t for t in tis if t.state in State.finished]
        if unfinished_tasks:
            scheduleable_tasks = [
                ut for ut in unfinished_tasks
                if ut.state in SCHEDULEABLE_STATES
            ]
            self.log.debug("number of scheduleable tasks for %s: %s task(s)",
                           self, len(scheduleable_tasks))
            schedulable_tis, changed_tis = self._get_ready_tis(
                scheduleable_tasks, finished_tasks, session)

        return TISchedulingDecision(
            tis=tis,
            schedulable_tis=schedulable_tis,
            changed_tis=changed_tis,
            unfinished_tasks=unfinished_tasks,
            finished_tasks=finished_tasks,
        )

    def _get_ready_tis(
        self,
        scheduleable_tasks: List[TI],
        finished_tasks: List[TI],
        session: Session,
    ) -> Tuple[List[TI], bool]:
        old_states = {}
        ready_tis: List[TI] = []
        changed_tis = False

        if not scheduleable_tasks:
            return ready_tis, changed_tis

        # Check dependencies
        for st in scheduleable_tasks:
            old_state = st.state
            if st.are_dependencies_met(
                    dep_context=DepContext(flag_upstream_failed=True,
                                           finished_tasks=finished_tasks),
                    session=session,
            ):
                ready_tis.append(st)
            else:
                old_states[st.key] = old_state

        # Check if any ti changed state
        tis_filter = TI.filter_for_tis(old_states.keys())
        if tis_filter is not None:
            fresh_tis = session.query(TI).filter(tis_filter).all()
            changed_tis = any(ti.state != old_states[ti.key]
                              for ti in fresh_tis)

        return ready_tis, changed_tis

    def _are_premature_tis(
        self,
        unfinished_tasks: List[TI],
        finished_tasks: List[TI],
        session: Session,
    ) -> bool:
        # There might be runnable tasks that are up for retry but, for some reason (retry delay, etc.),
        # are not ready yet, so we set the dep-context flags to count them in.
        for ut in unfinished_tasks:
            if ut.are_dependencies_met(
                    dep_context=DepContext(
                        flag_upstream_failed=True,
                        ignore_in_retry_period=True,
                        ignore_in_reschedule_period=True,
                        finished_tasks=finished_tasks,
                    ),
                    session=session,
            ):
                return True
        return False

    def _emit_true_scheduling_delay_stats_for_finished_state(
            self, finished_tis):
        """
        This is a helper method to emit the true scheduling delay stats, which is defined as
        the time when the first task in the DAG starts minus the expected DagRun datetime.
        This method is used in the update_state method when the state of the DagRun is
        updated to a completed status (either success or failure). It finds the first
        started task within the DAG, calculates the expected DagRun start time (based on
        dag.execution_date and dag.timetable), and subtracts the two to get the delay.
        The emitted data may contain outliers (e.g. when the first task was cleared, so
        the second task's start_date is used), but those can be filtered out on the
        stats side through the dashboard tooling built on top of them.
        Note that the stat is only emitted if the DagRun is a scheduler-triggered one
        (i.e. external_trigger is False).
        """
        if self.state == State.RUNNING:
            return
        if self.external_trigger:
            return
        if not finished_tis:
            return

        try:
            dag = self.get_dag()

            if not self.dag.timetable.periodic:
                # We can't emit this metric if there is no following schedule to calculate from!
                return

            ordered_tis_by_start_date = [
                ti for ti in finished_tis if ti.start_date
            ]
            ordered_tis_by_start_date.sort(key=lambda ti: ti.start_date,
                                           reverse=False)
            first_start_date = ordered_tis_by_start_date[0].start_date
            if first_start_date:
                # TODO: Logically, this should be DagRunInfo.run_after, but the
                # information is not stored on a DagRun, only before the actual
                # execution on DagModel.next_dagrun_create_after. We should add
                # a field on DagRun for this instead of relying on the run
                # always happening immediately after the data interval.
                data_interval_end = dag.get_run_data_interval(self).end
                true_delay = first_start_date - data_interval_end
                if true_delay.total_seconds() > 0:
                    Stats.timing(
                        f'dagrun.{dag.dag_id}.first_task_scheduling_delay',
                        true_delay)
        except Exception as e:
            self.log.warning(
                f'Failed to record first_task_scheduling_delay metric:\n{e}')

    def _emit_duration_stats_for_finished_state(self):
        if self.state == State.RUNNING:
            return
        if self.start_date is None:
            self.log.warning(
                'Failed to record duration of %s: start_date is not set.',
                self)
            return
        if self.end_date is None:
            self.log.warning(
                'Failed to record duration of %s: end_date is not set.', self)
            return

        duration = self.end_date - self.start_date
        if self.state == State.SUCCESS:
            Stats.timing(f'dagrun.duration.success.{self.dag_id}', duration)
        elif self.state == State.FAILED:
            Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)

    @provide_session
    def verify_integrity(self, session: Session = NEW_SESSION):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                        ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task
                                   is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'",
                    ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        def task_filter(task: "BaseOperator"):
            return task.task_id not in task_ids and (
                self.is_backfill or task.start_date <= self.execution_date)

        created_counts: Dict[str, int] = defaultdict(int)

        # The default no-op hook in airflow.settings sets ``is_noop = True``; if the attribute
        # is missing, the hook has been replaced by the user and must be called per TI.
        hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

        if hook_is_noop:

            def create_ti_mapping(task: "BaseOperator"):
                created_counts[task.task_type] += 1
                return TI.insert_mapping(self.run_id, task)

        else:

            def create_ti(task: "BaseOperator") -> TI:
                ti = TI(task, run_id=self.run_id)
                task_instance_mutation_hook(ti)
                created_counts[ti.operator] += 1
                return ti

        # Create missing tasks
        tasks = list(filter(task_filter, dag.task_dict.values()))
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, map(create_ti_mapping, tasks))
            else:
                session.bulk_save_objects(map(create_ti, tasks))

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created-{task_type}", count)
            session.flush()
        except IntegrityError:
            self.log.info(
                'Hit IntegrityError while creating the TIs for %s - %s',
                dag.dag_id,
                self.run_id,
                exc_info=True,
            )
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()

    @staticmethod
    def get_run(session: Session, dag_id: str,
                execution_date: datetime) -> Optional['DagRun']:
        """
        Get a single DAG Run

        :meta private:
        :param session: Sqlalchemy ORM Session
        :param dag_id: DAG ID
        :param execution_date: execution date
        :return: DagRun corresponding to the given dag_id and execution date
            if one exists. None otherwise.
        :rtype: airflow.models.DagRun
        """
        warnings.warn(
            "This method is deprecated. Please use SQLAlchemy directly",
            DeprecationWarning,
            stacklevel=2,
        )
        return (session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.external_trigger == False,  # noqa
            DagRun.execution_date == execution_date,
        ).first())

    @property
    def is_backfill(self) -> bool:
        return self.run_type == DagRunType.BACKFILL_JOB

    @classmethod
    @provide_session
    def get_latest_runs(cls, session=None) -> List['DagRun']:
        """Returns the latest DagRun for each DAG"""
        subquery = (session.query(
            cls.dag_id,
            func.max(cls.execution_date).label('execution_date')).group_by(
                cls.dag_id).subquery())
        return (session.query(cls).join(
            subquery,
            and_(cls.dag_id == subquery.c.dag_id,
                 cls.execution_date == subquery.c.execution_date),
        ).all())

    @provide_session
    def schedule_tis(self,
                     schedulable_tis: Iterable[TI],
                     session: Session = NEW_SESSION) -> int:
        """
        Set the given task instances into the scheduled state.

        Each element of ``schedulable_tis`` should have its ``task`` attribute already set.

        Any DummyOperator without callbacks is instead set straight to the success state.

        All the TIs should belong to this DagRun, but because this code is on the hot path that is
        not checked -- it is the caller's responsibility to call this function only with TIs from a
        single dag run.
        """
        # Get the list of TI IDs that do not need to be executed; these are
        # tasks using DummyOperator without on_execute_callback / on_success_callback
        dummy_ti_ids = []
        schedulable_ti_ids = []
        for ti in schedulable_tis:
            if (ti.task.inherits_from_dummy_operator
                    and not ti.task.on_execute_callback
                    and not ti.task.on_success_callback):
                dummy_ti_ids.append(ti.task_id)
            else:
                schedulable_ti_ids.append(ti.task_id)

        count = 0

        if schedulable_ti_ids:
            count += (session.query(TI).filter(
                TI.dag_id == self.dag_id,
                TI.run_id == self.run_id,
                TI.task_id.in_(schedulable_ti_ids),
            ).update({TI.state: State.SCHEDULED}, synchronize_session=False))

        # Tasks using DummyOperator should not be executed, mark them as success
        if dummy_ti_ids:
            count += (session.query(TI).filter(
                TI.dag_id == self.dag_id,
                TI.run_id == self.run_id,
                TI.task_id.in_(dummy_ti_ids),
            ).update(
                {
                    TI.state: State.SUCCESS,
                    TI.start_date: timezone.utcnow(),
                    TI.end_date: timezone.utcnow(),
                    TI.duration: 0,
                },
                synchronize_session=False,
            ))

        return count

    @provide_session
    def get_log_filename_template(self,
                                  *,
                                  session: Session = NEW_SESSION) -> str:
        if self.log_template_id is None:  # DagRun created before LogTemplate introduction.
            template = session.query(LogTemplate.filename).order_by(
                LogTemplate.id).limit(1).scalar()
        else:
            template = session.query(LogTemplate.filename).filter_by(
                id=self.log_template_id).scalar()
        if template is None:
            raise AirflowException(
                f"No log_template entry found for ID {self.log_template_id!r}. "
                f"Please make sure you set up the metadatabase correctly.")
        return template

    @provide_session
    def get_task_prefix_template(self,
                                 *,
                                 session: Session = NEW_SESSION) -> str:
        if self.log_template_id is None:  # DagRun created before LogTemplate introduction.
            template = session.query(LogTemplate.task_prefix).order_by(
                LogTemplate.id).limit(1).scalar()
        else:
            template = session.query(LogTemplate.task_prefix).filter_by(
                id=self.log_template_id).scalar()
        if template is None:
            raise AirflowException(
                f"No log_template entry found for ID {self.log_template_id!r}. "
                f"Please make sure you set up the metadatabase correctly.")
        return template
    def start(self):
        self.task_queue = Queue()
        self.result_queue = Queue()
        framework = mesos_pb2.FrameworkInfo()
        framework.user = ""

        if not conf.get("mesos", "MASTER"):
            logging.error("Expecting mesos master URL for mesos executor")
            raise AirflowException("mesos.master not provided for mesos executor")

        master = conf.get("mesos", "MASTER")

        if not conf.get("mesos", "FRAMEWORK_NAME"):
            framework.name = "Airflow"
        else:
            framework.name = conf.get("mesos", "FRAMEWORK_NAME")

        if not conf.get("mesos", "TASK_CPU"):
            task_cpu = 1
        else:
            task_cpu = conf.getint("mesos", "TASK_CPU")

        if not conf.get("mesos", "TASK_MEMORY"):
            task_memory = 256
        else:
            task_memory = conf.getint("mesos", "TASK_MEMORY")

        if conf.getboolean("mesos", "CHECKPOINT"):
            framework.checkpoint = True
        else:
            framework.checkpoint = False

        logging.info(
            "MesosFramework master : %s, name : %s, cpu : %s, mem : %s, checkpoint : %s",
            master,
            framework.name,
            str(task_cpu),
            str(task_memory),
            str(framework.checkpoint),
        )

        implicit_acknowledgements = 1

        if conf.getboolean("mesos", "AUTHENTICATE"):
            if not conf.get("mesos", "DEFAULT_PRINCIPAL"):
                logging.error("Expecting authentication principal in the environment")
                raise AirflowException("mesos.default_principal not provided in authenticated mode")
            if not conf.get("mesos", "DEFAULT_SECRET"):
                logging.error("Expecting authentication secret in the environment")
                raise AirflowException("mesos.default_secret not provided in authenticated mode")

            credential = mesos_pb2.Credential()
            credential.principal = conf.get("mesos", "DEFAULT_PRINCIPAL")
            credential.secret = conf.get("mesos", "DEFAULT_SECRET")

            framework.principal = credential.principal

            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue, self.result_queue, task_cpu, task_memory),
                framework,
                master,
                implicit_acknowledgements,
                credential,
            )
        else:
            framework.principal = "Airflow"
            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue, self.result_queue, task_cpu, task_memory),
                framework,
                master,
                implicit_acknowledgements,
            )

        self.mesos_driver = driver
        self.mesos_driver.start()
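
# The option lookups above use explicit "if not conf.get(...)" checks with hard-coded
# defaults. A more compact sketch of the same lookups, assuming the typed getters accept
# a ``fallback=`` argument (as ``conf.get`` does in the first example). Note the semantics
# differ slightly: ``fallback`` only applies when the option is missing, while the checks
# above also treat an empty value as "not set".
from airflow.configuration import conf

framework_name = conf.get("mesos", "FRAMEWORK_NAME", fallback="Airflow")
task_cpu = conf.getint("mesos", "TASK_CPU", fallback=1)
task_memory = conf.getint("mesos", "TASK_MEMORY", fallback=256)
checkpoint = conf.getboolean("mesos", "CHECKPOINT", fallback=False)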
Exemple #38
0
    def __init__(
        self,
        dag_directory: Union[str, "pathlib.Path"],
        max_runs: int,
        processor_timeout: timedelta,
        signal_conn: MultiprocessingConnection,
        dag_ids: Optional[List[str]],
        pickle_dags: bool,
        async_mode: bool = True,
    ):
        super().__init__()
        self._file_paths: List[str] = []
        self._file_path_queue: List[str] = []
        self._dag_directory = dag_directory
        self._max_runs = max_runs
        self._signal_conn = signal_conn
        self._pickle_dags = pickle_dags
        self._dag_ids = dag_ids
        self._async_mode = async_mode
        self._parsing_start_time: Optional[int] = None

        # Set the signal conn into non-blocking mode, so that attempting to
        # send when the buffer is full raises an error rather than hanging
        # forever (this is to avoid deadlocks!)
        #
        # Don't do this in sync_mode, as we _need_ the DagParsingStat to be sent
        # to continue the scheduler
        if self._async_mode:
            os.set_blocking(self._signal_conn.fileno(), False)

        self._parallelism = conf.getint('scheduler', 'parsing_processes')
        if conf.get('core', 'sql_alchemy_conn').startswith(
                'sqlite') and self._parallelism > 1:
            self.log.warning(
                "Because we cannot use more than 1 thread (parsing_processes = "
                "%d) when using sqlite, parallelism is set to 1.",
                self._parallelism,
            )
            self._parallelism = 1

        # Parse and schedule each file no faster than this interval.
        self._file_process_interval = conf.getint('scheduler',
                                                  'min_file_process_interval')
        # How often to print out DAG file processing stats to the log. Defaults to
        # 30 seconds.
        self.print_stats_interval = conf.getint('scheduler',
                                                'print_stats_interval')
        # How many seconds we wait for tasks to heartbeat before marking them as zombies.
        self._zombie_threshold_secs = conf.getint(
            'scheduler', 'scheduler_zombie_task_threshold')

        # Map from file path to the processor
        self._processors: Dict[str, DagFileProcessorProcess] = {}

        self._num_run = 0

        # Map from file path to stats about the file
        self._file_stats: Dict[str, DagFileStat] = {}

        self._last_zombie_query_time = None
        # Last time that the DAG dir was traversed to look for files
        self.last_dag_dir_refresh_time = timezone.make_aware(
            datetime.fromtimestamp(0))
        # Last time stats were printed
        self.last_stat_print_time = 0
        # TODO: Remove magic number
        self._zombie_query_interval = 10
        # How long to wait before timing out a process to parse a DAG file
        self._processor_timeout = processor_timeout

        # How often to scan the DAGs directory for new files. Defaults to 5 minutes.
        self.dag_dir_list_interval = conf.getint('scheduler',
                                                 'dag_dir_list_interval')

        # Mapping from file name to callback requests
        self._callback_to_execute: Dict[
            str, List[CallbackRequest]] = defaultdict(list)

        self._log = logging.getLogger('airflow.processor_manager')

        self.waitables: Dict[Any, Union[MultiprocessingConnection,
                                        DagFileProcessorProcess]] = {
                                            self._signal_conn:
                                            self._signal_conn,
                                        }
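
# A minimal standalone sketch of the non-blocking signal connection configured in the
# async_mode branch above (the connection names are illustrative, not part of the
# original example):
import os
from multiprocessing import Pipe

parent_conn, child_conn = Pipe()
# Make sends fail fast with an error instead of blocking when the OS pipe buffer is
# full -- the deadlock the comment above warns about.
os.set_blocking(parent_conn.fileno(), False)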
Exemple #39
0
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.kube_env_vars = configuration_dict.get('kubernetes_environment_variables', {})
        self.airflow_home = configuration.get(self.core_section, 'airflow_home')
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section, 'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(
            self.kubernetes_section, 'worker_container_tag')
        self.kube_image = '{}:{}'.format(
            self.worker_container_repository, self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy"
        )
        self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
        self.kube_annotations = configuration_dict.get('kubernetes_annotations', {})
        self.delete_worker_pods = conf.getboolean(
            self.kubernetes_section, 'delete_worker_pods')
        self.worker_pods_creation_batch_size = conf.getint(
            self.kubernetes_section, 'worker_pods_creation_batch_size')
        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section, 'image_pull_secrets')

        # NOTE: the user can build the DAGs into the docker image directly;
        # this is set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section, 'dags_in_image')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')
        # Optionally, the root directory for git operations
        self.git_sync_root = conf.get(self.kubernetes_section, 'git_sync_root')
        # Optionally, the name at which to publish the checked-out files under --root
        self.git_sync_dest = conf.get(self.kubernetes_section, 'git_sync_dest')
        # Optionally, if git_dags_folder_mount_point is set the worker will use
        # {git_dags_folder_mount_point}/{git_sync_dest}/{git_subpath} as dags_folder
        self.git_dags_folder_mount_point = conf.get(self.kubernetes_section,
                                                    'git_dags_folder_mount_point')

        # Optionally a user may supply a (`git_user` AND `git_password`) OR
        # (`git_ssh_key_secret_name` AND `git_ssh_key_secret_key`) for private repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')
        self.git_ssh_key_secret_name = conf.get(self.kubernetes_section, 'git_ssh_key_secret_name')
        self.git_ssh_known_hosts_configmap_name = conf.get(self.kubernetes_section,
                                                           'git_ssh_known_hosts_configmap_name')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section, 'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section, 'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(
            self.kubernetes_section, 'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(
            self.kubernetes_section, 'logs_volume_subpath')

        # Optionally, hostPath volume containing DAGs
        self.dags_volume_host = conf.get(self.kubernetes_section, 'dags_volume_host')

        # Optionally, write logs to a hostPath Volume
        self.logs_volume_host = conf.get(self.kubernetes_section, 'logs_volume_host')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section, 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside.
        # Note that if your cluster has RBAC enabled, your scheduler may need
        # service account permissions to create, watch, get, and delete pods in
        # this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor.
        # Note that if your cluster has RBAC enabled, your workers may need
        # service account permissions to interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section, 'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(
            self.kubernetes_section, 'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section, 'airflow_configmap')

        affinity_json = conf.get(self.kubernetes_section, 'affinity')
        if affinity_json:
            self.kube_affinity = json.loads(affinity_json)
        else:
            self.kube_affinity = None

        tolerations_json = conf.get(self.kubernetes_section, 'tolerations')
        if tolerations_json:
            self.kube_tolerations = json.loads(tolerations_json)
        else:
            self.kube_tolerations = None

        self._validate()
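
# A hedged sketch of the JSON option handling used for ``affinity`` and ``tolerations``
# above (the literal value below is illustrative only):
import json

affinity_json = '{"nodeAffinity": {"requiredDuringSchedulingIgnoredDuringExecution": []}}'
kube_affinity = json.loads(affinity_json) if affinity_json else None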
Exemple #40
0
class DagBag(BaseDagBag, LoggingMixin):
    """
    A dagbag is a collection of dags, parsed out of a folder tree, that carries
    high-level configuration settings, like what database to use as a backend and
    what executor to use to fire off tasks. This makes it easier to run distinct
    environments, e.g. for production and development, for tests, or for
    different teams or security profiles. What would have been system-level
    settings are now dagbag-level, so that one system can run multiple,
    independent settings sets.

    :param dag_folder: the folder to scan to find DAGs
    :type dag_folder: unicode
    :param include_examples: whether to include the examples that ship
        with airflow or not
    :type include_examples: bool
    :param include_smart_sensor: whether to include the smart sensor native
        DAGs that create the smart sensor operators for the whole cluster
    :type include_smart_sensor: bool
    :param read_dags_from_db: Read DAGs from DB if ``True``.
        If ``False``, DAGs are read from python files. This property is not used when
        determining whether or not to write Serialized DAGs; that is done by checking
        the config ``store_serialized_dags``.
    :type read_dags_from_db: bool
    """

    DAGBAG_IMPORT_TIMEOUT = conf.getint('core', 'DAGBAG_IMPORT_TIMEOUT')
    SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint(
        'scheduler', 'scheduler_zombie_task_threshold')

    def __init__(
        self,
        dag_folder: Optional[str] = None,
        include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
        include_smart_sensor: bool = conf.getboolean('smart_sensor',
                                                     'USE_SMART_SENSOR'),
        safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
        read_dags_from_db: bool = False,
        store_serialized_dags: Optional[bool] = None,
    ):
        # Avoid circular import
        from airflow.models.dag import DAG
        super().__init__()

        if store_serialized_dags:
            warnings.warn(
                "The store_serialized_dags parameter has been deprecated. "
                "You should pass the read_dags_from_db parameter.",
                DeprecationWarning,
                stacklevel=2)
            read_dags_from_db = store_serialized_dags

        dag_folder = dag_folder or settings.DAGS_FOLDER
        self.dag_folder = dag_folder
        self.dags: Dict[str, DAG] = {}
        # the file's last modified timestamp when we last read it
        self.file_last_changed: Dict[str, datetime] = {}
        self.import_errors: Dict[str, str] = {}
        self.has_logged = False
        self.read_dags_from_db = read_dags_from_db
        # Only used by read_dags_from_db=True
        self.dags_last_fetched: Dict[str, datetime] = {}

        self.collect_dags(dag_folder=dag_folder,
                          include_examples=include_examples,
                          include_smart_sensor=include_smart_sensor,
                          safe_mode=safe_mode)
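
    # A hedged usage sketch (the folder path and dag id are illustrative only):
    #
    #     dagbag = DagBag(dag_folder="/path/to/dags", include_examples=False)
    #     dag = dagbag.get_dag("my_dag_id")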

    def size(self) -> int:
        """
        :return: the number of dags contained in this dagbag
        """
        return len(self.dags)

    @property
    def store_serialized_dags(self) -> bool:
        """Whether or not to read dags from DB"""
        warnings.warn(
            "The store_serialized_dags property has been deprecated. "
            "Use read_dags_from_db instead.",
            DeprecationWarning,
            stacklevel=2)
        return self.read_dags_from_db

    @property
    def dag_ids(self) -> List[str]:
        return list(self.dags.keys())

    def get_dag(self, dag_id):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        if self.read_dags_from_db:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel
            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                self._add_dag_from_db(dag_id=dag_id)
                return self.dags.get(dag_id)

            # If DAG is in the DagBag, check the following
            # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
            # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
            # 3. if (2) is yes, fetch the Serialized DAG.
            min_serialized_dag_fetch_secs = timedelta(
                seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
            if (dag_id in self.dags_last_fetched
                    and timezone.utcnow() > self.dags_last_fetched[dag_id] +
                    min_serialized_dag_fetch_secs):
                sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(
                    dag_id=dag_id)
                if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
                    self._add_dag_from_db(dag_id=dag_id)

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id

        # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
        orm_dag = DagModel.get_current(root_dag_id)
        if not orm_dag:
            return self.dags.get(dag_id)

        # If the dag corresponding to root_dag_id is absent or expired
        is_missing = root_dag_id not in self.dags
        is_expired = (orm_dag.last_expired
                      and dag.last_loaded < orm_dag.last_expired)
        if is_missing or is_expired:
            # Reprocess source file
            found_dags = self.process_file(filepath=correct_maybe_zipped(
                orm_dag.fileloc),
                                           only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [
                    found_dag.dag_id for found_dag in found_dags
            ]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)

    def _add_dag_from_db(self, dag_id: str):
        """Add DAG to DagBag from DB"""
        from airflow.models.serialized_dag import SerializedDagModel
        row = SerializedDagModel.get(dag_id)
        if not row:
            raise ValueError(
                f"DAG '{dag_id}' not found in serialized_dag table")

        dag = row.dag
        for subdag in dag.subdags:
            self.dags[subdag.dag_id] = subdag
        self.dags[dag.dag_id] = dag
        self.dags_last_fetched[dag.dag_id] = timezone.utcnow()

    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for dag objects within it.
        """
        integrate_dag_plugins()

        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return []

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(
                os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.file_last_changed \
                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
                return []
        except Exception as e:  # pylint: disable=broad-except
            self.log.exception(e)
            return []

        if not zipfile.is_zipfile(filepath):
            mods = self._load_modules_from_file(filepath, safe_mode)
        else:
            mods = self._load_modules_from_zip(filepath, safe_mode)

        found_dags = self._process_modules(filepath, mods,
                                           file_last_changed_on_disk)

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags

    def _load_modules_from_file(self, filepath, safe_mode):
        if not might_contain_dag(filepath, safe_mode):
            # Don't want to spam user with skip messages
            if not self.has_logged:
                self.has_logged = True
                self.log.info("File %s assumed to contain no DAGs. Skipping.",
                              filepath)
            return []

        self.log.debug("Importing %s", filepath)
        org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
        path_hash = hashlib.sha1(filepath.encode('utf-8')).hexdigest()
        mod_name = f'unusual_prefix_{path_hash}_{org_mod_name}'

        if mod_name in sys.modules:
            del sys.modules[mod_name]

        with timeout(self.DAGBAG_IMPORT_TIMEOUT):
            try:
                loader = importlib.machinery.SourceFileLoader(
                    mod_name, filepath)
                spec = importlib.util.spec_from_loader(mod_name, loader)
                new_module = importlib.util.module_from_spec(spec)
                sys.modules[spec.name] = new_module
                loader.exec_module(new_module)
                return [new_module]
            except Exception as e:  # pylint: disable=broad-except
                self.log.exception("Failed to import: %s", filepath)
                self.import_errors[filepath] = str(e)
        return []
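
    # Note: for a file like "/dags/my_dag.py" the generated module name looks like
    # "unusual_prefix_<sha1 of the path>_my_dag", which avoids collisions between DAG
    # files that share a basename in different folders.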

    def _load_modules_from_zip(self, filepath, safe_mode):
        mods = []
        current_zip_file = zipfile.ZipFile(filepath)
        for zip_info in current_zip_file.infolist():
            head, _ = os.path.split(zip_info.filename)
            mod_name, ext = os.path.splitext(zip_info.filename)
            if ext not in [".py", ".pyc"]:
                continue
            if head:
                continue

            if mod_name == '__init__':
                self.log.warning("Found __init__.%s at root of %s", ext,
                                 filepath)

            self.log.debug("Reading %s from %s", zip_info.filename, filepath)

            if not might_contain_dag(zip_info.filename, safe_mode,
                                     current_zip_file):
                # todo: create ignore list
                # Don't want to spam user with skip messages
                if not self.has_logged:
                    self.has_logged = True
                    self.log.info(
                        "File %s:%s assumed to contain no DAGs. Skipping.",
                        filepath, zip_info.filename)
                continue

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            try:
                sys.path.insert(0, filepath)
                current_module = importlib.import_module(mod_name)
                mods.append(current_module)
            except Exception as e:  # pylint: disable=broad-except
                self.log.exception("Failed to import: %s", filepath)
                self.import_errors[filepath] = str(e)
        return mods

    def _process_modules(self, filepath, mods, file_last_changed_on_disk):
        from airflow.models.dag import DAG  # Avoid circular import

        is_zipfile = zipfile.is_zipfile(filepath)
        top_level_dags = [
            o for m in mods for o in list(m.__dict__.values())
            if isinstance(o, DAG)
        ]

        found_dags = []

        for dag in top_level_dags:
            if not dag.full_filepath:
                dag.full_filepath = filepath
                if dag.fileloc != filepath and not is_zipfile:
                    dag.fileloc = filepath
            try:
                dag.is_subdag = False
                self.bag_dag(dag=dag, root_dag=dag)
                if isinstance(dag.normalized_schedule_interval, str):
                    croniter(dag.normalized_schedule_interval)
                found_dags.append(dag)
                found_dags += dag.subdags
            except (CroniterBadCronError, CroniterBadDateError,
                    CroniterNotAlphaError) as cron_e:
                self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                self.import_errors[
                    dag.full_filepath] = f"Invalid Cron expression: {cron_e}"
                self.file_last_changed[dag.full_filepath] = \
                    file_last_changed_on_disk
            except (AirflowDagCycleException,
                    AirflowClusterPolicyViolation) as exception:
                self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                self.import_errors[dag.full_filepath] = str(exception)
                self.file_last_changed[
                    dag.full_filepath] = file_last_changed_on_disk
        return found_dags

    def bag_dag(self, dag, root_dag):
        """
        Adds the DAG into the bag and recurses into sub dags.
        Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags.
        """
        test_cycle(dag)  # throws if a task cycle is found

        dag.resolve_template_files()
        dag.last_loaded = timezone.utcnow()

        for task in dag.tasks:
            settings.policy(task)

        subdags = dag.subdags

        try:
            for subdag in subdags:
                subdag.full_filepath = dag.full_filepath
                subdag.parent_dag = dag
                subdag.is_subdag = True
                self.bag_dag(dag=subdag, root_dag=root_dag)

            self.dags[dag.dag_id] = dag
            self.log.debug('Loaded DAG %s', dag)
        except AirflowDagCycleException as cycle_exception:
            # There was an error in bagging the dag. Remove it from the list of dags
            self.log.exception('Exception bagging dag: %s', dag.dag_id)
            # Only necessary at the root level since DAG.subdags automatically
            # performs DFS to search through all subdags
            if dag == root_dag:
                for subdag in subdags:
                    if subdag.dag_id in self.dags:
                        del self.dags[subdag.dag_id]
            raise cycle_exception

    def collect_dags(self,
                     dag_folder=None,
                     only_if_updated=True,
                     include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
                     include_smart_sensor=conf.getboolean(
                         'smart_sensor', 'USE_SMART_SENSOR'),
                     safe_mode=conf.getboolean('core',
                                               'DAG_DISCOVERY_SAFE_MODE')):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a ``.airflowignore`` file is found while processing
        the directory, it will behave much like a ``.gitignore``,
        ignoring files that match any of the regex patterns specified
        in the file.

        **Note**: The patterns in .airflowignore are treated as
        un-anchored regexes, not shell-like glob patterns.
        """
        if self.read_dags_from_db:
            return

        self.log.info("Filling up the DagBag from %s", dag_folder)
        start_dttm = timezone.utcnow()
        dag_folder = dag_folder or self.dag_folder
        # Used to store stats around DagBag processing
        stats = []

        dag_folder = correct_maybe_zipped(dag_folder)
        for filepath in list_py_file_paths(
                dag_folder,
                safe_mode=safe_mode,
                include_examples=include_examples,
                include_smart_sensor=include_smart_sensor):
            try:
                file_parse_start_dttm = timezone.utcnow()
                found_dags = self.process_file(filepath,
                                               only_if_updated=only_if_updated,
                                               safe_mode=safe_mode)

                file_parse_end_dttm = timezone.utcnow()
                stats.append(
                    FileLoadStat(
                        file=filepath.replace(settings.DAGS_FOLDER, ''),
                        duration=file_parse_end_dttm - file_parse_start_dttm,
                        dag_num=len(found_dags),
                        task_num=sum([len(dag.tasks) for dag in found_dags]),
                        dags=str([dag.dag_id for dag in found_dags]),
                    ))
            except Exception as e:  # pylint: disable=broad-except
                self.log.exception(e)

        end_dttm = timezone.utcnow()
        durations = (end_dttm - start_dttm).total_seconds()
        Stats.gauge('collect_dags', durations, 1)
        Stats.gauge('dagbag_size', len(self.dags), 1)
        Stats.gauge('dagbag_import_errors', len(self.import_errors), 1)
        self.dagbag_stats = sorted(stats,
                                   key=lambda x: x.duration,
                                   reverse=True)
        for file_stat in self.dagbag_stats:
            # file_stat.file similar format: /subdir/dag_name.py
            # TODO: Remove for Airflow 2.0
            filename = file_stat.file.split('/')[-1].replace('.py', '')
            Stats.timing('dag.loading-duration.{}'.format(filename),
                         file_stat.duration)

    def collect_dags_from_db(self):
        """Collects DAGs from database."""
        from airflow.models.serialized_dag import SerializedDagModel
        start_dttm = timezone.utcnow()
        self.log.info("Filling up the DagBag from database")

        # The dagbag contains all rows in serialized_dag table. Deleted DAGs are deleted
        # from the table by the scheduler job.
        self.dags = SerializedDagModel.read_all_dags()

        # Adds subdags.
        # DAG post-processing steps such as self.bag_dag and croniter are not needed as
        # they are done by scheduler before serialization.
        subdags = {}
        for dag in self.dags.values():
            for subdag in dag.subdags:
                subdags[subdag.dag_id] = subdag
        self.dags.update(subdags)

        Stats.timing('collect_db_dags', timezone.utcnow() - start_dttm)

    def dagbag_report(self):
        """Prints a report around DagBag loading stats"""
        stats = self.dagbag_stats
        dag_folder = self.dag_folder
        duration = sum([o.duration for o in stats],
                       timedelta()).total_seconds()
        dag_num = sum([o.dag_num for o in stats])
        task_num = sum([o.task_num for o in stats])
        table = tabulate(stats, headers="keys")

        report = textwrap.dedent(f"""\n
        -------------------------------------------------------------------
        DagBag loading stats for {dag_folder}
        -------------------------------------------------------------------
        Number of DAGs: {dag_num}
        Total task number: {task_num}
        DagBag parsing time: {duration}
        {table}
        """)
        return report

    def sync_to_db(self):
        """
        Save attributes about the list of DAGs to the DB.
        """
        # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
        from airflow.models.dag import DAG
        from airflow.models.serialized_dag import SerializedDagModel
        self.log.debug("Calling the DAG.bulk_sync_to_db method")
        DAG.bulk_sync_to_db(self.dags.values())
        # Write Serialized DAGs to DB if DAG Serialization is turned on
        # Even though self.read_dags_from_db is False
        if settings.STORE_SERIALIZED_DAGS:
            self.log.debug(
                "Calling the SerializedDagModel.bulk_sync_to_db method")
            SerializedDagModel.bulk_sync_to_db(self.dags.values())
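A hedged usage sketch of the DagBag methods above (the dag folder path is illustrative): build a bag from a folder, print its per-file parsing report, and sync the collected DAGs to the metadata database only when every file imported cleanly.

from airflow.models import DagBag

dagbag = DagBag(dag_folder="/path/to/dags", include_examples=False)
print(dagbag.dagbag_report())      # per-file parsing stats
if not dagbag.import_errors:       # skip the sync if any file failed to import
    dagbag.sync_to_db()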
Exemple #41
0
    def __init__(
        self,
        dag_directory: str,
        max_runs: int,
        processor_factory: Callable[[str, List[CallbackRequest]], AbstractDagFileProcessorProcess],
        processor_timeout: timedelta,
        signal_conn: MultiprocessingConnection,
        dag_ids: Optional[List[str]],
        pickle_dags: bool,
        async_mode: bool = True,
    ):
        super().__init__()
        self._file_paths: List[str] = []
        self._file_path_queue: List[str] = []
        self._dag_directory = dag_directory
        self._max_runs = max_runs
        self._processor_factory = processor_factory
        self._signal_conn = signal_conn
        self._pickle_dags = pickle_dags
        self._dag_ids = dag_ids
        self._async_mode = async_mode
        self._parsing_start_time: Optional[datetime] = None

        self._parallelism = conf.getint('scheduler', 'max_threads')
        if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1:
            self.log.warning(
                "Because we cannot use more than 1 thread (max_threads = "
                "%d ) when using sqlite. So we set parallelism to 1.",
                self._parallelism,
            )
            self._parallelism = 1

        # Parse and schedule each file no faster than this interval.
        self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval')
        # How often to print out DAG file processing stats to the log. Default to
        # 30 seconds.
        self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval')
        # How many seconds to wait for tasks to heartbeat before marking them as zombies.
        self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

        # Should store dag file source in a database?
        self.store_dag_code = STORE_DAG_CODE
        # Map from file path to the processor
        self._processors: Dict[str, AbstractDagFileProcessorProcess] = {}

        self._num_run = 0

        # Map from file path to stats about the file
        self._file_stats: Dict[str, DagFileStat] = {}

        self._last_zombie_query_time = None
        # Last time that the DAG dir was traversed to look for files
        self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0))
        # Last time stats were printed
        self.last_stat_print_time = timezone.datetime(2000, 1, 1)
        # TODO: Remove magic number
        self._zombie_query_interval = 10
        # How long to wait before timing out a process to parse a DAG file
        self._processor_timeout = processor_timeout

        # How often to scan the DAGs directory for new files. Default to 5 minutes.
        self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval')

        # Mapping file name and callbacks requests
        self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list)

        self._log = logging.getLogger('airflow.processor_manager')

        self.waitables: Dict[Any, Union[MultiprocessingConnection, AbstractDagFileProcessorProcess]] = {
            self._signal_conn: self._signal_conn,
        }
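A minimal standalone sketch of the sqlite guard in the constructor above, assuming the same [scheduler]/[core] option names used there: sqlite cannot serve concurrent writers, so parsing parallelism is forced down to a single process.

from airflow.configuration import conf

parallelism = conf.getint('scheduler', 'max_threads')
if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and parallelism > 1:
    parallelism = 1  # sqlite supports only one writer at a time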
Exemple #42
0
    def is_alive(self):
        return (
            (datetime.now() - self.latest_heartbeat).seconds <
            (conf.getint('scheduler', 'JOB_HEARTBEAT_SEC') * 2.1)
        )
Exemple #43
0
    def is_alive(self):
        return (datetime.now() - self.latest_heartbeat).seconds < (conf.getint("scheduler", "JOB_HEARTBEAT_SEC") * 2.1)
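A worked illustration of the liveness rule in the two variants above: with JOB_HEARTBEAT_SEC set to 5 (an example value), a job whose last heartbeat was 8 seconds ago is still alive because 8 < 5 * 2.1 = 10.5.

from datetime import datetime, timedelta

job_heartbeat_sec = 5                                    # example JOB_HEARTBEAT_SEC
latest_heartbeat = datetime.now() - timedelta(seconds=8)
alive = (datetime.now() - latest_heartbeat).seconds < job_heartbeat_sec * 2.1
print(alive)  # True: 8 < 10.5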
Exemple #44
0
    def __init__(
            self,
            task_id: str,
            owner: str = conf.get('operators', 'DEFAULT_OWNER'),
            email: Optional[str] = None,
            email_on_retry: bool = True,
            email_on_failure: bool = True,
            retries: Optional[int] = None,
            retry_delay: timedelta = timedelta(seconds=300),
            retry_exponential_backoff: bool = False,
            max_retry_delay: Optional[datetime] = None,
            start_date: Optional[datetime] = None,
            end_date: Optional[datetime] = None,
            depends_on_past: bool = False,
            wait_for_downstream: bool = False,
            dag: Optional[DAG] = None,
            params: Optional[Dict] = None,
            default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
            priority_weight: int = 1,
            weight_rule: str = WeightRule.DOWNSTREAM,
            queue: str = conf.get('celery', 'default_queue'),
            pool: str = Pool.DEFAULT_POOL_NAME,
            sla: Optional[timedelta] = None,
            execution_timeout: Optional[timedelta] = None,
            on_failure_callback: Optional[Callable] = None,
            on_success_callback: Optional[Callable] = None,
            on_retry_callback: Optional[Callable] = None,
            trigger_rule: str = TriggerRule.ALL_SUCCESS,
            resources: Optional[Dict] = None,
            run_as_user: Optional[str] = None,
            task_concurrency: Optional[int] = None,
            executor_config: Optional[Dict] = None,
            do_xcom_push: bool = True,
            inlets: Optional[Dict] = None,
            outlets: Optional[Dict] = None,
            *args,
            **kwargs):

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries if retries is not None else \
            conf.getint('core', 'default_task_retries', fallback=0)
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules,
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(
            *resources) if resources is not None else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: List[DataSet]
        self.outlets = []  # type: List[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)
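A hedged usage sketch of the defaults handled by this constructor (DummyOperator stands in for any concrete operator; the import path is the Airflow 1.10-era one): an integer retry_delay is coerced into a timedelta, and retries would fall back to [core] default_task_retries if left as None.

from datetime import timedelta
from airflow.operators.dummy_operator import DummyOperator

task = DummyOperator(
    task_id="example_task",
    retries=3,
    retry_delay=120,   # plain int of seconds; coerced to timedelta(seconds=120)
)
assert task.retry_delay == timedelta(seconds=120)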
Exemple #45
0
from builtins import range
from builtins import object
import logging

from airflow.utils import State
from airflow.configuration import conf

PARALLELISM = conf.getint('core', 'PARALLELISM')


class BaseExecutor(object):

    def __init__(self, parallelism=PARALLELISM):
        """
        Class to derive in order to interface with executor-type systems
        like Celery, Mesos, Yarn and the like.

        :param parallelism: how many jobs should run at one time. Set to
            ``0`` for infinity
        :type parallelism: int
        """
        self.parallelism = parallelism
        self.queued_tasks = {}
        self.running = {}
        self.event_buffer = {}

    def start(self):  # pragma: no cover
        """
        Executors may need to get things started. For example LocalExecutor
        starts N workers.
        """
Exemple #46
0
class DagBag(LoggingMixin):
    """
    A dagbag is a collection of dags, parsed out of a folder tree and has high
    level configuration settings, like what database to use as a backend and
    what executor to use to fire off tasks. This makes it easier to run
    distinct environments for say production and development, tests, or for
    different teams or security profiles. What would have been system level
    settings are now dagbag level so that one system can run multiple,
    independent settings sets.

    :param dag_folder: the folder to scan to find DAGs
    :type dag_folder: unicode
    :param include_examples: whether to include the examples that ship
        with airflow or not
    :type include_examples: bool
    :param include_smart_sensor: whether to include the smart sensor native
        DAGs that create the smart sensor operators for whole cluster
    :type include_smart_sensor: bool
    :param read_dags_from_db: Read DAGs from DB if ``True`` is passed.
        If ``False`` DAGs are read from python files.
    :type read_dags_from_db: bool
    :param load_op_links: Should the extra operator link be loaded via plugins when
        de-serializing the DAG? This flag is set to False in Scheduler so that Extra Operator links
        are not loaded, to avoid running user code in the scheduler.
    :type load_op_links: bool
    """

    DAGBAG_IMPORT_TIMEOUT = conf.getfloat('core', 'DAGBAG_IMPORT_TIMEOUT')
    SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint(
        'scheduler', 'scheduler_zombie_task_threshold')

    def __init__(
        self,
        dag_folder: Union[str, "pathlib.Path", None] = None,
        include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
        include_smart_sensor: bool = conf.getboolean('smart_sensor',
                                                     'USE_SMART_SENSOR'),
        safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
        read_dags_from_db: bool = False,
        store_serialized_dags: Optional[bool] = None,
        load_op_links: bool = True,
    ):
        # Avoid circular import
        from airflow.models.dag import DAG

        super().__init__()

        if store_serialized_dags:
            warnings.warn(
                "The store_serialized_dags parameter has been deprecated. "
                "You should pass the read_dags_from_db parameter.",
                DeprecationWarning,
                stacklevel=2,
            )
            read_dags_from_db = store_serialized_dags

        dag_folder = dag_folder or settings.DAGS_FOLDER
        self.dag_folder = dag_folder
        self.dags: Dict[str, DAG] = {}
        # the file's last modified timestamp when we last read it
        self.file_last_changed: Dict[str, datetime] = {}
        self.import_errors: Dict[str, str] = {}
        self.has_logged = False
        self.read_dags_from_db = read_dags_from_db
        # Only used by read_dags_from_db=True
        self.dags_last_fetched: Dict[str, datetime] = {}
        # Only used by SchedulerJob to compare the dag_hash to identify change in DAGs
        self.dags_hash: Dict[str, str] = {}

        self.dagbag_import_error_tracebacks = conf.getboolean(
            'core', 'dagbag_import_error_tracebacks')
        self.dagbag_import_error_traceback_depth = conf.getint(
            'core', 'dagbag_import_error_traceback_depth')
        self.collect_dags(
            dag_folder=dag_folder,
            include_examples=include_examples,
            include_smart_sensor=include_smart_sensor,
            safe_mode=safe_mode,
        )
        # Should the extra operator link be loaded via plugins?
        # This flag is set to False in Scheduler so that Extra Operator links are not loaded
        self.load_op_links = load_op_links

    def size(self) -> int:
        """:return: the amount of dags contained in this dagbag"""
        return len(self.dags)

    @property
    def store_serialized_dags(self) -> bool:
        """Whether or not to read dags from DB"""
        warnings.warn(
            "The store_serialized_dags property has been deprecated. Use read_dags_from_db instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.read_dags_from_db

    @property
    def dag_ids(self) -> List[str]:
        """
        :return: a list of DAG IDs in this bag
        :rtype: List[unicode]
        """
        return list(self.dags.keys())

    @provide_session
    def get_dag(self, dag_id, session: Session = None):
        """
        Gets the DAG out of the dictionary, and refreshes it if expired

        :param dag_id: DAG Id
        :type dag_id: str
        """
        # Avoid circular import
        from airflow.models.dag import DagModel

        if self.read_dags_from_db:
            # Import here so that serialized dag is only imported when serialization is enabled
            from airflow.models.serialized_dag import SerializedDagModel

            if dag_id not in self.dags:
                # Load from DB if not (yet) in the bag
                self._add_dag_from_db(dag_id=dag_id, session=session)
                return self.dags.get(dag_id)

            # If DAG is in the DagBag, check the following
            # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
            # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
            # 3. if (2) is yes, fetch the Serialized DAG.
            # 4. if (2) returns None (i.e. Serialized DAG is deleted), remove dag from dagbag
            # if it exists and return None.
            min_serialized_dag_fetch_secs = timedelta(
                seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
            if (dag_id in self.dags_last_fetched
                    and timezone.utcnow() > self.dags_last_fetched[dag_id] +
                    min_serialized_dag_fetch_secs):
                sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(
                    dag_id=dag_id,
                    session=session,
                )
                if not sd_last_updated_datetime:
                    self.log.warning("Serialized DAG %s no longer exists",
                                     dag_id)
                    del self.dags[dag_id]
                    del self.dags_last_fetched[dag_id]
                    del self.dags_hash[dag_id]
                    return None

                if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
                    self._add_dag_from_db(dag_id=dag_id, session=session)

            return self.dags.get(dag_id)

        # If asking for a known subdag, we want to refresh the parent
        dag = None
        root_dag_id = dag_id
        if dag_id in self.dags:
            dag = self.dags[dag_id]
            if dag.is_subdag:
                root_dag_id = dag.parent_dag.dag_id  # type: ignore

        # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized?
        orm_dag = DagModel.get_current(root_dag_id, session=session)
        if not orm_dag:
            return self.dags.get(dag_id)

        # If the dag corresponding to root_dag_id is absent or expired
        is_missing = root_dag_id not in self.dags
        is_expired = orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired
        if is_expired:
            # Remove associated dags so we can re-add them.
            self.dags = {
                key: dag
                for key, dag in self.dags.items()
                if root_dag_id != key and not (
                    dag.is_subdag and root_dag_id == dag.parent_dag.dag_id)
            }
        if is_missing or is_expired:
            # Reprocess source file.
            found_dags = self.process_file(filepath=correct_maybe_zipped(
                orm_dag.fileloc),
                                           only_if_updated=False)

            # If the source file no longer exports `dag_id`, delete it from self.dags
            if found_dags and dag_id in [
                    found_dag.dag_id for found_dag in found_dags
            ]:
                return self.dags[dag_id]
            elif dag_id in self.dags:
                del self.dags[dag_id]
        return self.dags.get(dag_id)

    def _add_dag_from_db(self, dag_id: str, session: Session):
        """Add DAG to DagBag from DB"""
        from airflow.models.serialized_dag import SerializedDagModel

        row = SerializedDagModel.get(dag_id, session)
        if not row:
            raise SerializedDagNotFound(
                f"DAG '{dag_id}' not found in serialized_dag table")

        row.load_op_links = self.load_op_links
        dag = row.dag
        for subdag in dag.subdags:
            self.dags[subdag.dag_id] = subdag
        self.dags[dag.dag_id] = dag
        self.dags_last_fetched[dag.dag_id] = timezone.utcnow()
        self.dags_hash[dag.dag_id] = row.dag_hash

    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for dag objects within it.
        """
        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return []

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(
                os.path.getmtime(filepath))
            if (only_if_updated and filepath in self.file_last_changed
                    and file_last_changed_on_disk
                    == self.file_last_changed[filepath]):
                return []
        except Exception as e:
            self.log.exception(e)
            return []

        if not zipfile.is_zipfile(filepath):
            mods = self._load_modules_from_file(filepath, safe_mode)
        else:
            mods = self._load_modules_from_zip(filepath, safe_mode)

        found_dags = self._process_modules(filepath, mods,
                                           file_last_changed_on_disk)

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags

    def _load_modules_from_file(self, filepath, safe_mode):
        if not might_contain_dag(filepath, safe_mode):
            # Don't want to spam user with skip messages
            if not self.has_logged:
                self.has_logged = True
                self.log.info("File %s assumed to contain no DAGs. Skipping.",
                              filepath)
            return []

        self.log.debug("Importing %s", filepath)
        org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
        path_hash = hashlib.sha1(filepath.encode('utf-8')).hexdigest()
        mod_name = f'unusual_prefix_{path_hash}_{org_mod_name}'

        if mod_name in sys.modules:
            del sys.modules[mod_name]

        timeout_msg = f"DagBag import timeout for {filepath} after {self.DAGBAG_IMPORT_TIMEOUT}s"
        with timeout(self.DAGBAG_IMPORT_TIMEOUT, error_message=timeout_msg):
            try:
                loader = importlib.machinery.SourceFileLoader(
                    mod_name, filepath)
                spec = importlib.util.spec_from_loader(mod_name, loader)
                new_module = importlib.util.module_from_spec(spec)
                sys.modules[spec.name] = new_module
                loader.exec_module(new_module)
                return [new_module]
            except Exception as e:
                self.log.exception("Failed to import: %s", filepath)
                if self.dagbag_import_error_tracebacks:
                    self.import_errors[filepath] = traceback.format_exc(
                        limit=-self.dagbag_import_error_traceback_depth)
                else:
                    self.import_errors[filepath] = str(e)
        return []

    def _load_modules_from_zip(self, filepath, safe_mode):
        mods = []
        with zipfile.ZipFile(filepath) as current_zip_file:
            for zip_info in current_zip_file.infolist():
                head, _ = os.path.split(zip_info.filename)
                mod_name, ext = os.path.splitext(zip_info.filename)
                if ext not in [".py", ".pyc"]:
                    continue
                if head:
                    continue

                if mod_name == '__init__':
                    self.log.warning("Found __init__.%s at root of %s", ext,
                                     filepath)

                self.log.debug("Reading %s from %s", zip_info.filename,
                               filepath)

                if not might_contain_dag(zip_info.filename, safe_mode,
                                         current_zip_file):
                    # todo: create ignore list
                    # Don't want to spam user with skip messages
                    if not self.has_logged:
                        self.has_logged = True
                        self.log.info(
                            "File %s:%s assumed to contain no DAGs. Skipping.",
                            filepath, zip_info.filename)
                    continue

                if mod_name in sys.modules:
                    del sys.modules[mod_name]

                try:
                    sys.path.insert(0, filepath)
                    current_module = importlib.import_module(mod_name)
                    mods.append(current_module)
                except Exception as e:
                    fileloc = os.path.join(filepath, zip_info.filename)
                    self.log.exception("Failed to import: %s", fileloc)
                    if self.dagbag_import_error_tracebacks:
                        self.import_errors[fileloc] = traceback.format_exc(
                            limit=-self.dagbag_import_error_traceback_depth)
                    else:
                        self.import_errors[fileloc] = str(e)
        return mods

    def _process_modules(self, filepath, mods, file_last_changed_on_disk):
        from airflow.models.dag import DAG  # Avoid circular import

        top_level_dags = ((o, m) for m in mods for o in m.__dict__.values()
                          if isinstance(o, DAG))

        found_dags = []

        for (dag, mod) in top_level_dags:
            dag.fileloc = mod.__file__
            try:
                dag.is_subdag = False
                dag.timetable.validate()
                self.bag_dag(dag=dag, root_dag=dag)
                found_dags.append(dag)
                found_dags += dag.subdags
            except AirflowTimetableInvalid as exception:
                self.log.exception("Failed to bag_dag: %s", dag.fileloc)
                self.import_errors[
                    dag.fileloc] = f"Invalid timetable expression: {exception}"
                self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
            except (
                    AirflowDagCycleException,
                    AirflowDagDuplicatedIdException,
                    AirflowClusterPolicyViolation,
            ) as exception:
                self.log.exception("Failed to bag_dag: %s", dag.fileloc)
                self.import_errors[dag.fileloc] = str(exception)
                self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
        return found_dags

    def bag_dag(self, dag, root_dag):
        """
        Adds the DAG into the bag, recurses into sub dags.

        :raises: AirflowDagCycleException if a cycle is detected in this dag or its subdags.
        :raises: AirflowDagDuplicatedIdException if this dag or its subdags already exists in the bag.
        """
        self._bag_dag(dag=dag, root_dag=root_dag, recursive=True)

    def _bag_dag(self, *, dag, root_dag, recursive):
        """Actual implementation of bagging a dag.

        The only purpose of this is to avoid exposing ``recursive`` in ``bag_dag()``,
        intended to only be used by the ``_bag_dag()`` implementation.
        """
        check_cycle(dag)  # throws if a task cycle is found

        dag.resolve_template_files()
        dag.last_loaded = timezone.utcnow()

        # Check policies
        settings.dag_policy(dag)

        for task in dag.tasks:
            settings.task_policy(task)

        subdags = dag.subdags

        try:
            # DAG.subdags automatically performs DFS search, so we don't recurse
            # into further _bag_dag() calls.
            if recursive:
                for subdag in subdags:
                    subdag.fileloc = dag.fileloc
                    subdag.parent_dag = dag
                    subdag.is_subdag = True
                    self._bag_dag(dag=subdag,
                                  root_dag=root_dag,
                                  recursive=False)

            prev_dag = self.dags.get(dag.dag_id)
            if prev_dag and prev_dag.fileloc != dag.fileloc:
                raise AirflowDagDuplicatedIdException(
                    dag_id=dag.dag_id,
                    incoming=dag.fileloc,
                    existing=self.dags[dag.dag_id].fileloc,
                )
            self.dags[dag.dag_id] = dag
            self.log.debug('Loaded DAG %s', dag)
        except (AirflowDagCycleException, AirflowDagDuplicatedIdException):
            # There was an error in bagging the dag. Remove it from the list of dags
            self.log.exception('Exception bagging dag: %s', dag.dag_id)
            # Only necessary at the root level since DAG.subdags automatically
            # performs DFS to search through all subdags
            if recursive:
                for subdag in subdags:
                    if subdag.dag_id in self.dags:
                        del self.dags[subdag.dag_id]
            raise

    def collect_dags(
        self,
        dag_folder: Union[str, "pathlib.Path", None] = None,
        only_if_updated: bool = True,
        include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
        include_smart_sensor: bool = conf.getboolean('smart_sensor',
                                                     'USE_SMART_SENSOR'),
        safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
    ):
        """
        Given a file path or a folder, this method looks for python modules,
        imports them and adds them to the dagbag collection.

        Note that if a ``.airflowignore`` file is found while processing
        the directory, it will behave much like a ``.gitignore``,
        ignoring files that match any of the regex patterns specified
        in the file.

        **Note**: The patterns in .airflowignore are treated as
        un-anchored regexes, not shell-like glob patterns.
        """
        if self.read_dags_from_db:
            return

        self.log.info("Filling up the DagBag from %s", dag_folder)
        dag_folder = dag_folder or self.dag_folder
        # Used to store stats around DagBag processing
        stats = []

        # Ensure dag_folder is a str -- it may have been a pathlib.Path
        dag_folder = correct_maybe_zipped(str(dag_folder))
        for filepath in list_py_file_paths(
                dag_folder,
                safe_mode=safe_mode,
                include_examples=include_examples,
                include_smart_sensor=include_smart_sensor,
        ):
            try:
                file_parse_start_dttm = timezone.utcnow()
                found_dags = self.process_file(filepath,
                                               only_if_updated=only_if_updated,
                                               safe_mode=safe_mode)

                file_parse_end_dttm = timezone.utcnow()
                stats.append(
                    FileLoadStat(
                        file=filepath.replace(settings.DAGS_FOLDER, ''),
                        duration=file_parse_end_dttm - file_parse_start_dttm,
                        dag_num=len(found_dags),
                        task_num=sum(len(dag.tasks) for dag in found_dags),
                        dags=str([dag.dag_id for dag in found_dags]),
                    ))
            except Exception as e:
                self.log.exception(e)

        self.dagbag_stats = sorted(stats,
                                   key=lambda x: x.duration,
                                   reverse=True)

    def collect_dags_from_db(self):
        """Collects DAGs from database."""
        from airflow.models.serialized_dag import SerializedDagModel

        with Stats.timer('collect_db_dags'):
            self.log.info("Filling up the DagBag from database")

            # The dagbag contains all rows in serialized_dag table. Deleted DAGs are deleted
            # from the table by the scheduler job.
            self.dags = SerializedDagModel.read_all_dags()

            # Adds subdags.
            # DAG post-processing steps such as self.bag_dag and croniter are not needed as
            # they are done by scheduler before serialization.
            subdags = {}
            for dag in self.dags.values():
                for subdag in dag.subdags:
                    subdags[subdag.dag_id] = subdag
            self.dags.update(subdags)

    def dagbag_report(self):
        """Prints a report around DagBag loading stats"""
        stats = self.dagbag_stats
        dag_folder = self.dag_folder
        duration = sum((o.duration for o in stats),
                       timedelta()).total_seconds()
        dag_num = sum(o.dag_num for o in stats)
        task_num = sum(o.task_num for o in stats)
        table = tabulate(stats, headers="keys")

        report = textwrap.dedent(f"""\n
        -------------------------------------------------------------------
        DagBag loading stats for {dag_folder}
        -------------------------------------------------------------------
        Number of DAGs: {dag_num}
        Total task number: {task_num}
        DagBag parsing time: {duration}
        {table}
        """)
        return report

    @provide_session
    def sync_to_db(self, session: Optional[Session] = None):
        """Save attributes about list of DAG to the DB."""
        # To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
        from airflow.models.dag import DAG
        from airflow.models.serialized_dag import SerializedDagModel

        def _serialize_dag_capturing_errors(dag, session):
            """
            Try to serialize the dag to the DB, but make a note of any errors.

            We can't place them directly in import_errors, as this may be retried and succeed the next time.
            """
            if dag.is_subdag:
                return []
            try:
                # We can't use bulk_write_to_db as we want to capture each error individually
                dag_was_updated = SerializedDagModel.write_dag(
                    dag,
                    min_update_interval=settings.
                    MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
                    session=session,
                )
                if dag_was_updated:
                    self._sync_perm_for_dag(dag, session=session)
                return []
            except OperationalError:
                raise
            except Exception:
                self.log.exception("Failed to write serialized DAG: %s",
                                   dag.full_filepath)
                return [(dag.fileloc,
                         traceback.format_exc(
                             limit=-self.dagbag_import_error_traceback_depth))]

        # Retry 'DAG.bulk_write_to_db' & 'SerializedDagModel.bulk_sync_to_db' in case
        # of any Operational Errors
        # In case of failures, provide_session handles rollback
        for attempt in run_with_db_retries(logger=self.log):
            with attempt:
                serialize_errors = []
                self.log.debug(
                    "Running dagbag.sync_to_db with retries. Try %d of %d",
                    attempt.retry_state.attempt_number,
                    MAX_DB_RETRIES,
                )
                self.log.debug("Calling the DAG.bulk_sync_to_db method")
                try:
                    # Write Serialized DAGs to DB, capturing errors
                    for dag in self.dags.values():
                        serialize_errors.extend(
                            _serialize_dag_capturing_errors(dag, session))

                    DAG.bulk_write_to_db(self.dags.values(), session=session)
                except OperationalError:
                    session.rollback()
                    raise
                # Only now we are "complete" do we update import_errors - don't want to record errors from
                # previous failed attempts
                self.import_errors.update(dict(serialize_errors))

    @provide_session
    def _sync_perm_for_dag(self, dag, session: Optional[Session] = None):
        """Sync DAG specific permissions, if necessary"""
        from flask_appbuilder.security.sqla import models as sqla_models

        from airflow.security.permissions import DAG_ACTIONS, resource_name_for_dag

        def needs_perm_views(dag_id: str) -> bool:
            dag_resource_name = resource_name_for_dag(dag_id)
            for permission_name in DAG_ACTIONS:
                if not (session.query(sqla_models.PermissionView).join(
                        sqla_models.Permission).join(
                            sqla_models.ViewMenu).filter(
                                sqla_models.Permission.name == permission_name
                            ).filter(sqla_models.ViewMenu.name ==
                                     dag_resource_name).one_or_none()):
                    return True
            return False

        if dag.access_control or needs_perm_views(dag.dag_id):
            self.log.debug("Syncing DAG permissions: %s to the DB", dag.dag_id)
            from airflow.www.security import ApplessAirflowSecurityManager

            security_manager = ApplessAirflowSecurityManager(session=session)
            security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control)
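A hedged usage sketch of the database-backed path in the class above: with read_dags_from_db=True, collect_dags returns immediately and get_dag lazily loads serialized rows. The dag_id below is hypothetical and must already exist in the serialized_dag table, otherwise SerializedDagNotFound is raised.

from airflow.models import DagBag

db_dagbag = DagBag(read_dags_from_db=True)
dag = db_dagbag.get_dag("example_dag_id")   # hypothetical dag_id
print(dag.dag_id, len(dag.tasks))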
Exemple #47
0
    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
        """
        Given a path to a python module or zip file, this method imports
        the module and looks for dag objects within it.
        """
        from airflow.models.dag import DAG  # Avoid circular import

        found_dags = []

        # if the source file no longer exists in the DB or in the filesystem,
        # return an empty list
        # todo: raise exception?
        if filepath is None or not os.path.isfile(filepath):
            return found_dags

        try:
            # This failed before in what may have been a git sync
            # race condition
            file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
            if only_if_updated \
                    and filepath in self.file_last_changed \
                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
                return found_dags

        except Exception as e:
            self.log.exception(e)
            return found_dags

        mods = []
        is_zipfile = zipfile.is_zipfile(filepath)
        if not is_zipfile:
            if safe_mode:
                with open(filepath, 'rb') as file:
                    content = file.read()
                    if not all([s in content for s in (b'DAG', b'airflow')]):
                        self.file_last_changed[filepath] = file_last_changed_on_disk
                        # Don't want to spam user with skip messages
                        if not self.has_logged:
                            self.has_logged = True
                            self.log.info(
                                "File %s assumed to contain no DAGs. Skipping.",
                                filepath)
                        return found_dags

            self.log.debug("Importing %s", filepath)
            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
            mod_name = ('unusual_prefix_' +
                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
                        '_' + org_mod_name)

            if mod_name in sys.modules:
                del sys.modules[mod_name]

            with timeout(conf.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
                try:
                    m = imp.load_source(mod_name, filepath)
                    mods.append(m)
                except Exception as e:
                    self.log.exception("Failed to import: %s", filepath)
                    self.import_errors[filepath] = str(e)
                    self.file_last_changed[filepath] = file_last_changed_on_disk

        else:
            zip_file = zipfile.ZipFile(filepath)
            for mod in zip_file.infolist():
                head, _ = os.path.split(mod.filename)
                mod_name, ext = os.path.splitext(mod.filename)
                if not head and (ext == '.py' or ext == '.pyc'):
                    if mod_name == '__init__':
                        self.log.warning("Found __init__.%s at root of %s", ext, filepath)
                    if safe_mode:
                        with zip_file.open(mod.filename) as zf:
                            self.log.debug("Reading %s from %s", mod.filename, filepath)
                            content = zf.read()
                            if not all([s in content for s in (b'DAG', b'airflow')]):
                                self.file_last_changed[filepath] = (
                                    file_last_changed_on_disk)
                                # todo: create ignore list
                                # Don't want to spam user with skip messages
                                if not self.has_logged:
                                    self.has_logged = True
                                    self.log.info(
                                        "File %s assumed to contain no DAGs. Skipping.",
                                        filepath)

                    if mod_name in sys.modules:
                        del sys.modules[mod_name]

                    try:
                        sys.path.insert(0, filepath)
                        m = importlib.import_module(mod_name)
                        mods.append(m)
                    except Exception as e:
                        self.log.exception("Failed to import: %s", filepath)
                        self.import_errors[filepath] = str(e)
                        self.file_last_changed[filepath] = file_last_changed_on_disk

        for m in mods:
            for dag in list(m.__dict__.values()):
                if isinstance(dag, DAG):
                    if not dag.full_filepath:
                        dag.full_filepath = filepath
                        if dag.fileloc != filepath and not is_zipfile:
                            dag.fileloc = filepath
                    try:
                        dag.is_subdag = False
                        self.bag_dag(dag, parent_dag=dag, root_dag=dag)
                        if isinstance(dag._schedule_interval, str):
                            croniter(dag._schedule_interval)
                        found_dags.append(dag)
                        found_dags += dag.subdags
                    except (CroniterBadCronError,
                            CroniterBadDateError,
                            CroniterNotAlphaError) as cron_e:
                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                        self.import_errors[dag.full_filepath] = \
                            "Invalid Cron expression: " + str(cron_e)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk
                    except AirflowDagCycleException as cycle_exception:
                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
                        self.import_errors[dag.full_filepath] = str(cycle_exception)
                        self.file_last_changed[dag.full_filepath] = \
                            file_last_changed_on_disk

        self.file_last_changed[filepath] = file_last_changed_on_disk
        return found_dags
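A small standalone illustration of the safe-mode heuristic used above: a file is only parsed when its raw bytes contain both b'DAG' and b'airflow'.

content = b"from airflow import DAG\n\ndag = DAG('demo')\n"
looks_like_a_dag_file = all(s in content for s in (b'DAG', b'airflow'))
print(looks_like_a_dag_file)  # True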
Exemple #48
0
    def _execute(self):
        self.task_runner = get_task_runner(self)

        def signal_handler(signum, frame):
            """Setting kill signal handler"""
            self.log.error("Received SIGTERM. Terminating subprocesses")
            self.task_runner.terminate()
            self.handle_task_exit(128 + signum)
            return

        signal.signal(signal.SIGTERM, signal_handler)

        if not self.task_instance.check_and_change_state_before_execution(
                mark_success=self.mark_success,
                ignore_all_deps=self.ignore_all_deps,
                ignore_depends_on_past=self.ignore_depends_on_past,
                ignore_task_deps=self.ignore_task_deps,
                ignore_ti_state=self.ignore_ti_state,
                job_id=self.id,
                pool=self.pool,
                external_executor_id=self.external_executor_id,
        ):
            self.log.info("Task is not able to be run")
            return

        try:
            self.task_runner.start()

            heartbeat_time_limit = conf.getint(
                'scheduler', 'scheduler_zombie_task_threshold')

            # task callback invocation happens either here or in
            # self.heartbeat() instead of taskinstance._run_raw_task to
            # avoid race conditions
            #
            # When self.terminating is set to True by heartbeat_callback, this
            # loop should not be restarted. Otherwise self.handle_task_exit
            # will be invoked and we will end up with duplicated callbacks
            while not self.terminating:
                # Monitor the task to see if it's done. Wait in a syscall
                # (`os.wait`) for as long as possible so we notice the
                # subprocess finishing as quick as we can
                max_wait_time = max(
                    0,  # Make sure this value is never negative,
                    min(
                        (heartbeat_time_limit -
                         (timezone.utcnow() -
                          self.latest_heartbeat).total_seconds() * 0.75),
                        self.heartrate,
                    ),
                )

                return_code = self.task_runner.return_code(
                    timeout=max_wait_time)
                if return_code is not None:
                    self.handle_task_exit(return_code)
                    return

                self.heartbeat()

                # If it's been too long since we've heartbeat, then it's possible that
                # the scheduler rescheduled this task, so kill launched processes.
                # This can only really happen if the worker can't read the DB for a long time
                time_since_last_heartbeat = (
                    timezone.utcnow() - self.latest_heartbeat).total_seconds()
                if time_since_last_heartbeat > heartbeat_time_limit:
                    Stats.incr('local_task_job_prolonged_heartbeat_failure', 1,
                               1)
                    self.log.error("Heartbeat time limit exceeded!")
                    raise AirflowException(
                        f"Time since last heartbeat({time_since_last_heartbeat:.2f}s) exceeded limit "
                        f"({heartbeat_time_limit}s).")
        finally:
            self.on_kill()
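A worked example of the wait-time formula above: with a 300 s zombie threshold, a heartbeat 40 s ago and a 5 s heartrate, the monitor waits min(300 - 40 * 0.75, 5) = 5 seconds, clamped so it can never go negative.

heartbeat_time_limit = 300        # scheduler_zombie_task_threshold
seconds_since_heartbeat = 40
heartrate = 5

max_wait_time = max(
    0,  # never wait a negative amount of time
    min(heartbeat_time_limit - seconds_since_heartbeat * 0.75, heartrate),
)
print(max_wait_time)  # 5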
Exemple #49
0
def upgrade():
    """Apply Add scheduling_decision to DagRun and DAG"""
    conn = op.get_bind()
    is_sqlite = bool(conn.dialect.name == "sqlite")
    is_mssql = bool(conn.dialect.name == "mssql")
    timestamp = _get_timestamp(conn)

    if is_sqlite:
        op.execute("PRAGMA foreign_keys=off")

    with op.batch_alter_table('dag_run', schema=None) as batch_op:
        batch_op.add_column(
            sa.Column('last_scheduling_decision', timestamp, nullable=True))
        batch_op.create_index('idx_last_scheduling_decision',
                              ['last_scheduling_decision'],
                              unique=False)
        batch_op.add_column(sa.Column('dag_hash', sa.String(32),
                                      nullable=True))

    with op.batch_alter_table('dag', schema=None) as batch_op:
        batch_op.add_column(sa.Column('next_dagrun', timestamp, nullable=True))
        batch_op.add_column(
            sa.Column('next_dagrun_create_after', timestamp, nullable=True))
        # Create with nullable and no default, then ALTER to set values, to avoid table level lock
        batch_op.add_column(
            sa.Column('concurrency', sa.Integer(), nullable=True))
        batch_op.add_column(
            sa.Column('has_task_concurrency_limits',
                      sa.Boolean(),
                      nullable=True))

        batch_op.create_index('idx_next_dagrun_create_after',
                              ['next_dagrun_create_after'],
                              unique=False)

    try:
        from airflow.configuration import conf

        concurrency = conf.getint('core', 'dag_concurrency', fallback=16)
    except:  # noqa
        concurrency = 16

    # Set it to true here as it makes us take the slow/more complete path, and when it's next parsed by the
    # DagParser it will get set to the correct value.

    op.execute(f"""
        UPDATE dag SET
            concurrency={concurrency},
            has_task_concurrency_limits={1 if is_sqlite or is_mssql else sa.true()}
        where concurrency IS NULL
        """)

    with op.batch_alter_table('dag', schema=None) as batch_op:
        batch_op.alter_column('concurrency',
                              type_=sa.Integer(),
                              nullable=False)
        batch_op.alter_column('has_task_concurrency_limits',
                              type_=sa.Boolean(),
                              nullable=False)

    if is_sqlite:
        op.execute("PRAGMA foreign_keys=on")
Exemple #50
0
from airflow import executors, models, settings, utils
from airflow.configuration import conf
from airflow.utils import AirflowException, State


Base = models.Base
ID_LEN = models.ID_LEN

# Setting up a statsd client if needed
statsd = None
if conf.get("scheduler", "statsd_on"):
    from statsd import StatsClient

    statsd = StatsClient(
        host=conf.get("scheduler", "statsd_host"),
        port=conf.getint("scheduler", "statsd_port"),
        prefix=conf.get("scheduler", "statsd_prefix"),
    )
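An illustrative use of the module-level statsd client set up above (it stays None unless statsd_on is enabled); the metric name is made up for the example.

if statsd:
    statsd.incr('scheduler_heartbeat')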


class BaseJob(Base):
    """
    Abstract class to be derived for jobs. Jobs are processing items with state
    and duration that aren't task instances. For instance a BackfillJob is
    a collection of task instance runs, but should have its own state, start
    and end time.
    """

    __tablename__ = "job"

    id = Column(Integer, primary_key=True)