Example #1
def send_email(to, subject, html_content):
    SMTP_HOST = conf.get('smtp', 'SMTP_HOST')
    SMTP_MAIL_FROM = conf.get('smtp', 'SMTP_MAIL_FROM')
    SMTP_PORT = conf.get('smtp', 'SMTP_PORT')
    SMTP_USER = conf.get('smtp', 'SMTP_USER')
    SMTP_PASSWORD = conf.get('smtp', 'SMTP_PASSWORD')

    if isinstance(to, unicode) or isinstance(to, str):
        if ',' in to:
            to = to.split(',')
        elif ';' in to:
            to = to.split(';')
        else:
            to = [to]

    msg = MIMEMultipart('alternative')
    msg['Subject'] = subject
    msg['From'] = SMTP_MAIL_FROM
    msg['To'] = ", ".join(to)
    mime_text = MIMEText(html_content, 'html')
    msg.attach(mime_text)
    s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
    s.starttls()
    if SMTP_USER and SMTP_PASSWORD:
        s.login(SMTP_USER, SMTP_PASSWORD)
    logging.info("Sent an altert email to " + str(to))
    s.sendmail(SMTP_MAIL_FROM, to, msg.as_string())
    s.quit()
Example #2
    def test_env_var_config(self):
        opt = conf.get('testsection', 'testkey')
        self.assertEqual(opt, 'testvalue')

        opt = conf.get('testsection', 'testpercent')
        self.assertEqual(opt, 'with%percent')

        self.assertTrue(conf.has_option('testsection', 'testkey'))
Example #3
def get_kube_client(in_cluster=conf.getboolean('kubernetes', 'in_cluster'),
                    cluster_context=None,
                    config_file=None):
    if not in_cluster:
        if cluster_context is None:
            cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None)
        if config_file is None:
            config_file = conf.get('kubernetes', 'config_file', fallback=None)
    return _load_kube_config(in_cluster, cluster_context, config_file)
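The fallback pattern above mirrors the standard library parser; a standalone sketch with plain configparser (hypothetical option name, not Airflow's code):

from configparser import ConfigParser

parser = ConfigParser()
parser.read_string("[kubernetes]\nin_cluster = False\n")

# fallback=None returns None instead of raising NoOptionError when the option is absent.
cluster_context = parser.get('kubernetes', 'cluster_context', fallback=None)
print(cluster_context)  # -> None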
Example #4
File: cli.py Project: johnw424/airflow
def flower(args):
    broka = conf.get('celery', 'BROKER_URL')
    args.port = args.port or conf.get('celery', 'FLOWER_PORT')
    port = '--port=' + args.port
    api = ''
    if args.broker_api:
        api = '--broker_api=' + args.broker_api
    sp = subprocess.Popen(['flower', '-b', broka, port, api])
    sp.wait()
Example #5
def flower(args):
    broka = conf.get("celery", "BROKER_URL")
    args.port = args.port or conf.get("celery", "FLOWER_PORT")
    port = "--port=" + args.port
    api = ""
    if args.broker_api:
        api = "--broker_api=" + args.broker_api
    sp = subprocess.Popen(["flower", "-b", broka, port, api])
    sp.wait()
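A caller sketch for the function above (hypothetical argparse namespace; assumes flower is installed and the broker from the [celery] section is reachable):

import argparse

args = argparse.Namespace(port='5555', broker_api=None)
flower(args)  # blocks until the flower subprocess exits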
Example #6
    def authenticate(username, password):
        service_principal = "%s/%s" % (conf.get('kerberos', 'principal'), utils.get_fqdn())
        realm = conf.get("kerberos", "default_realm")
        user_principal = utils.principal_from_username(username)

        try:
            # this is pykerberos specific, verify = True is needed to prevent KDC spoofing
            if not kerberos.checkPassword(user_principal, password, service_principal, realm, True):
                raise AuthenticationError()
        except kerberos.KrbError as e:
            logging.error('Password validation for principal %s failed %s', user_principal, e)
            raise AuthenticationError(e)
Example #7
def configure_vars():
    global SQL_ALCHEMY_CONN
    global DAGS_FOLDER
    global PLUGINS_FOLDER
    SQL_ALCHEMY_CONN = conf.get('core', 'SQL_ALCHEMY_CONN')
    DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER'))

    PLUGINS_FOLDER = conf.get(
        'core',
        'plugins_folder',
        fallback=os.path.join(AIRFLOW_HOME, 'plugins')
    )
Example #8
def send_MIME_email(e_from, e_to, mime_msg):
    SMTP_HOST = conf.get('smtp', 'SMTP_HOST')
    SMTP_PORT = conf.get('smtp', 'SMTP_PORT')
    SMTP_USER = conf.get('smtp', 'SMTP_USER')
    SMTP_PASSWORD = conf.get('smtp', 'SMTP_PASSWORD')

    s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
    s.starttls()
    if SMTP_USER and SMTP_PASSWORD:
        s.login(SMTP_USER, SMTP_PASSWORD)
    logging.info("Sent an alert email to " + str(e_to))
    s.sendmail(e_from, e_to, mime_msg.as_string())
    s.quit()
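A minimal caller sketch (hypothetical addresses; assumes the send_MIME_email above and a reachable SMTP host from the [smtp] section):

from email.mime.text import MIMEText

msg = MIMEText('<b>DAG finished</b>', 'html')
msg['Subject'] = 'airflow alert'
send_MIME_email('airflow@example.com', ['ops@example.com'], msg)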
Example #9
    def log(self):
        BASE_LOG_FOLDER = conf.get('core', 'BASE_LOG_FOLDER')
        dag_id = request.args.get('dag_id')
        task_id = request.args.get('task_id')
        execution_date = request.args.get('execution_date')
        dag = dagbag.dags[dag_id]
        log_relative = "/{dag_id}/{task_id}/{execution_date}".format(
            **locals())
        loc = BASE_LOG_FOLDER + log_relative
        loc = loc.format(**locals())
        log = ""
        TI = models.TaskInstance
        session = Session()
        dttm = dateutil.parser.parse(execution_date)
        ti = session.query(TI).filter(
            TI.dag_id == dag_id, TI.task_id == task_id,
            TI.execution_date == dttm).first()
        if ti:
            host = ti.hostname
            if socket.gethostname() == host:
                try:
                    f = open(loc)
                    log += "".join(f.readlines())
                    f.close()
                except:
                    log = "Log file isn't where expected.\n".format(loc)
            else:
                WORKER_LOG_SERVER_PORT = \
                    conf.get('celery', 'WORKER_LOG_SERVER_PORT')
                url = (
                    "http://{host}:{WORKER_LOG_SERVER_PORT}/log"
                    "{log_relative}").format(**locals())
                log += "Log file isn't local."
                log += "Fetching here: {url}\n".format(**locals())
                try:
                    import urllib2
                    w = urllib2.urlopen(url)
                    log += w.read()
                    w.close()
                except:
                    log += "Failed to fetch log file.".format(**locals())
            session.commit()
            session.close()

        log = "<pre><code>{0}</code></pre>".format(log)
        title = "Logs for {task_id} on {execution_date}".format(**locals())
        html_code = log

        return self.render(
            'airflow/dag_code.html', html_code=html_code, dag=dag, title=title)
Example #10
File: utils.py Project: nkhuyu/airflow
def send_MIME_email(e_from, e_to, mime_msg):
    SMTP_HOST = conf.get("smtp", "SMTP_HOST")
    SMTP_PORT = conf.getint("smtp", "SMTP_PORT")
    SMTP_USER = conf.get("smtp", "SMTP_USER")
    SMTP_PASSWORD = conf.get("smtp", "SMTP_PASSWORD")
    SMTP_STARTTLS = conf.getboolean("smtp", "SMTP_STARTTLS")

    s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
    if SMTP_STARTTLS:
        s.starttls()
    if SMTP_USER and SMTP_PASSWORD:
        s.login(SMTP_USER, SMTP_PASSWORD)
    logging.info("Sent an alert email to " + str(e_to))
    s.sendmail(e_from, e_to, mime_msg.as_string())
    s.quit()
Example #11
File: utils.py Project: hoanghw/airflow
def send_MIME_email(e_from, e_to, mime_msg, dryrun=False):
    SMTP_HOST = conf.get('smtp', 'SMTP_HOST')
    SMTP_PORT = conf.getint('smtp', 'SMTP_PORT')
    SMTP_USER = conf.get('smtp', 'SMTP_USER')
    SMTP_PASSWORD = conf.get('smtp', 'SMTP_PASSWORD')
    SMTP_STARTTLS = conf.getboolean('smtp', 'SMTP_STARTTLS')

    if not dryrun:
        s = smtplib.SMTP(SMTP_HOST, SMTP_PORT)
        if SMTP_STARTTLS:
            s.starttls()
        if SMTP_USER and SMTP_PASSWORD:
            s.login(SMTP_USER, SMTP_PASSWORD)
        logging.info("Sent an alert email to " + str(e_to))
        s.sendmail(e_from, e_to, mime_msg.as_string())
        s.quit()
Example #12
File: utils.py Project: hoanghw/airflow
def send_email(to, subject, html_content, files=None, dryrun=False):
    """
    Send an email with html content

    >>> send_email('*****@*****.**', 'foo', '<b>Foo</b> bar', ['/dev/null'], dryrun=True)
    """
    SMTP_MAIL_FROM = conf.get('smtp', 'SMTP_MAIL_FROM')

    if isinstance(to, basestring):
        if ',' in to:
            to = to.split(',')
        elif ';' in to:
            to = to.split(';')
        else:
            to = [to]

    msg = MIMEMultipart('alternative')
    msg['Subject'] = subject
    msg['From'] = SMTP_MAIL_FROM
    msg['To'] = ", ".join(to)
    mime_text = MIMEText(html_content, 'html')
    msg.attach(mime_text)

    for fname in files or []:
        basename = os.path.basename(fname)
        with open(fname, "rb") as f:
            msg.attach(MIMEApplication(
                f.read(),
                Content_Disposition='attachment; filename="%s"' % basename,
                Name=basename
            ))

    send_MIME_email(SMTP_MAIL_FROM, to, msg, dryrun)
Example #13
def resetdb(args):
    print("DB: " + conf.get("core", "SQL_ALCHEMY_CONN"))
    if input("This will drop existing tables if they exist. " "Proceed? (y/n)").upper() == "Y":
        logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
        utils.resetdb()
    else:
        print("Bail.")
Example #14
def webserver(args):
    print(settings.HEADER)
    log_to_stdout()
    from airflow.www.app import app

    threads = args.threads or conf.get("webserver", "threads")
    if args.debug:
        print("Starting the web server on port {0} and host {1}.".format(args.port, args.hostname))
        app.run(debug=True, port=args.port, host=args.hostname)
    else:
        print(
            "Running the Gunicorn server with {threads}"
            "on host {args.hostname} and port "
            "{args.port}...".format(**locals())
        )
        sp = subprocess.Popen(
            [
                "gunicorn",
                "-w",
                str(args.threads),
                "-t",
                "120",
                "-b",
                args.hostname + ":" + str(args.port),
                "airflow.www.app:app",
            ]
        )
        sp.wait()
Example #15
 def __init__(self, host=None, db=None, port=None,
              presto_conn_id=conf.get('hooks', 'PRESTO_DEFAULT_CONN_ID')):
     self.user = '******'
     if not presto_conn_id:
         self.host = host
         self.db = db
         self.port = port
     else:
         session = settings.Session()
         db = session.query(
             Connection).filter(
                 Connection.conn_id == presto_conn_id)
         if db.count() == 0:
             raise Exception("The presto_conn_id you provided isn't defined")
         else:
             db = db.all()[0]
         self.host = db.host
         self.db = db.schema
         self.catalog = 'hive'
         self.port = db.port
         self.cursor = presto.connect(host=db.host, port=db.port,
                                      username=self.user,
                                      catalog=self.catalog,
                                      schema=db.schema).cursor()
         session.close()    # currently only a pass in pyhive
Example #16
def send_email(to, subject, html_content, files=None):
    SMTP_MAIL_FROM = conf.get('smtp', 'SMTP_MAIL_FROM')

    if isinstance(to, basestring):
        if ',' in to:
            to = to.split(',')
        elif ';' in to:
            to = to.split(';')
        else:
            to = [to]

    msg = MIMEMultipart('alternative')
    msg['Subject'] = Header(subject, 'utf-8')
    msg['From'] = SMTP_MAIL_FROM
    msg['To'] = ", ".join(to)
    mime_text = MIMEText(html_content, 'html', _charset='utf-8')
    msg.attach(mime_text)

    for fname in files or []:
        basename = os.path.basename(fname)
        with open(fname, "rb") as f:
            part = MIMEApplication(f.read())
            part.add_header('Content-Disposition', 'attachment', filename=basename)
            msg.attach(part)

    send_MIME_email(SMTP_MAIL_FROM, to, msg)
Example #17
File: cli.py Project: johnw424/airflow
 def serve_logs(filename):
     log = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER'))
     return flask.send_from_directory(
         log,
         filename,
         mimetype="application/json",
         as_attachment=False)
Example #18
def send_email(to, subject, html_content, files=None):
    SMTP_MAIL_FROM = conf.get('smtp', 'SMTP_MAIL_FROM')

    if isinstance(to, basestring):
        if ',' in to:
            to = to.split(',')
        elif ';' in to:
            to = to.split(';')
        else:
            to = [to]

    msg = MIMEMultipart('alternative')
    msg['Subject'] = subject
    msg['From'] = SMTP_MAIL_FROM
    msg['To'] = ", ".join(to)
    mime_text = MIMEText(html_content, 'html')
    msg.attach(mime_text)

    for fname in files or []:
        basename = os.path.basename(fname)
        with open(fname, "rb") as f:
            msg.attach(MIMEApplication(
                f.read(),
                Content_Disposition='attachment; filename="%s"' % basename,
                Name=basename
            ))

    send_MIME_email(SMTP_MAIL_FROM, to, msg)
Example #19
File: utils.py Project: Raynes/airflow
def send_email(to, subject, html_content, files=None):
    SMTP_MAIL_FROM = conf.get("smtp", "SMTP_MAIL_FROM")

    if isinstance(to, basestring):
        if "," in to:
            to = to.split(",")
        elif ";" in to:
            to = to.split(";")
        else:
            to = [to]

    msg = MIMEMultipart("alternative")
    msg["Subject"] = subject
    msg["From"] = SMTP_MAIL_FROM
    msg["To"] = ", ".join(to)
    mime_text = MIMEText(html_content, "html")
    msg.attach(mime_text)

    for fname in files or []:
        basename = os.path.basename(fname)
        with open(fname, "rb") as f:
            msg.attach(
                MIMEApplication(f.read(), Content_Disposition='attachment; filename="%s"' % basename, Name=basename)
            )

    send_MIME_email(SMTP_MAIL_FROM, to, msg)
Example #20
File: app.py Project: harakiro/airflow
    def get(self):
        session = settings.Session()
        dagbag = models.DagBag(os.path.expanduser(conf.get('core', 'DAGS_FOLDER')))
        DM = models.DagModel
        qry = None

        qry = session.query(DM).filter(~DM.is_subdag, DM.is_active).all()
        orm_dags = {dag.dag_id: dag for dag in qry}
        import_errors = session.query(models.ImportError).all()

        session.expunge_all()
        session.commit()
        session.close()
        dags = dagbag.dags.values()

        dag_list = []

        for dag in dags:
            d = {
                'id': dag.dag_id,
                'schedule_interval': dag.schedule_interval.__str__(),
                'start_date': dag.start_date,
                'last_loaded': dag.last_loaded
            }
            dag_list.append(d)

        print(dag_list)

        #dags = {dag.dag_id: dag.dag_id for dag in dags if not dag.parent_dag}
        #all_dag_ids = sorted(set(orm_dags.keys()) | set(dags.keys()))

        #return all_dag_ids

        """
Example #21
    def __init__(
            self,
            cwl_workflow,
            dag_id=None,
            default_args=None,
            *args, **kwargs):

        _dag_id = dag_id if dag_id else cwl_workflow.split("/")[-1].replace(".cwl", "").replace(".", "_dot_")
        super(self.__class__, self).__init__(dag_id=_dag_id,
                                             default_args=default_args, *args, **kwargs)

        if cwl_workflow not in __cwl__tools_loaded__:
            if os.path.isabs(cwl_workflow):
                cwl_base = ""
            else:
                cwl_base = conf.get('cwl', 'CWL_HOME')

            __cwl__tools_loaded__[cwl_workflow] = cwltool.main.load_tool(os.path.join(cwl_base, cwl_workflow), False,
                                                                         False,
                                                                         cwltool.workflow.defaultMakeTool, True)

            if type(__cwl__tools_loaded__[cwl_workflow]) == int \
                    or __cwl__tools_loaded__[cwl_workflow].tool["class"] != "Workflow":
                raise cwltool.errors.WorkflowException(
                    "Class '{0}' is not supported yet in CWLDAG".format(
                        __cwl__tools_loaded__[cwl_workflow].tool["class"]))

        self.cwlwf = __cwl__tools_loaded__[cwl_workflow]
Example #22
def max_partition(table, schema="default", hive_conn_id=conf.get("hooks", "HIVE_DEFAULT_CONN_ID")):
    from airflow.hooks.hive_hook import HiveHook

    if "." in table:
        schema, table = table.split(".")
    hh = HiveHook(hive_conn_id=hive_conn_id)
    return hh.max_partition(schema=schema, table_name=table)
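A usage sketch (hypothetical table name; assumes the default Hive connection from the [hooks] section exists):

latest = max_partition("analytics.events")  # schema taken from the dotted prefix
print(latest)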
Example #23
File: utils.py Project: nkhuyu/airflow
def upgradedb():
    logging.info("Creating tables")
    package_dir = os.path.abspath(os.path.dirname(__file__))
    directory = os.path.join(package_dir, "migrations")
    config = Config(os.path.join(package_dir, "alembic.ini"))
    config.set_main_option("script_location", directory)
    config.set_main_option("sqlalchemy.url", conf.get("core", "SQL_ALCHEMY_CONN"))
    command.upgrade(config, "head")
Example #24
def list_dags(args):
    if args.subdir:
        subdir = args.subdir.replace(
            "DAGS_FOLDER", conf.get("core", "DAGS_FOLDER"))
        subdir = os.path.expanduser(subdir)
    dagbag = DagBag(subdir)
    # dagbag = DagBag(os.path.expanduser(conf.get('core', 'DAGS_FOLDER')))
    print("\n".join(sorted(dagbag.dags)))
Example #25
    def __init__(
            self, hql,
            hive_conn_id=conf.get('hooks', 'HIVE_DEFAULT_CONN_ID'),
            *args, **kwargs):
        super(HiveOperator, self).__init__(*args, **kwargs)

        self.hive_conn_id = hive_conn_id
        self.hook = HiveHook(hive_conn_id=hive_conn_id)
        self.hql = hql
Example #26
    def get_conn(self):
        db = self.get_connection(self.hiveserver2_conn_id)
        auth_mechanism = db.extra_dejson.get("authMechanism", "NOSASL")
        if conf.get("core", "security") == "kerberos":
            auth_mechanism = db.extra_dejson.get("authMechanism", "KERBEROS")

        return pyhs2.connect(
            host=db.host, port=db.port, authMechanism=auth_mechanism, user=db.login, database=db.schema or "default"
        )
Example #27
File: utils.py Project: john5223/airflow
def upgradedb():
    logging.info("Creating tables")
    package_dir = os.path.abspath(os.path.dirname(__file__))
    directory = os.path.join(package_dir, 'migrations')
    config = Config(os.path.join(package_dir, 'alembic.ini'))
    config.set_main_option('script_location', directory)
    config.set_main_option('sqlalchemy.url',
                           conf.get('core', 'SQL_ALCHEMY_CONN'))
    command.upgrade(config, 'head')
Example #28
    def __init__(
            self, sql,
            presto_conn_id=conf.get('hooks', 'PRESTO_DEFAULT_CONN_ID'),
            *args, **kwargs):
        super(PrestoCheckOperator, self).__init__(*args, **kwargs)

        self.presto_conn_id = presto_conn_id
        self.hook = PrestoHook(presto_conn_id=presto_conn_id)
        self.sql = sql
Example #29
File: cli.py Project: DATAQC/airflow
def resetdb(args):
    print("DB: " + conf.get('core', 'SQL_ALCHEMY_CONN'))
    if raw_input(
            "This will drop existing tables if they exist. "
            "Proceed? (y/n)").upper() == "Y":
        logging.basicConfig(level=logging.DEBUG,
                            format=settings.SIMPLE_LOG_FORMAT)
        utils.resetdb()
    else:
        print("Bail.")
Example #30
 def _get_environment(self):
     """Defines any necessary environment variables for the pod executor"""
     env = {
         'AIRFLOW__CORE__DAGS_FOLDER': '/tmp/dags',
         'AIRFLOW__CORE__EXECUTOR': 'LocalExecutor',
         'AIRFLOW__CORE__SQL_ALCHEMY_CONN': conf.get('core', 'SQL_ALCHEMY_CONN')
     }
     if self.kube_config.airflow_configmap:
         env['AIRFLOW__CORE__AIRFLOW_HOME'] = self.worker_airflow_home
     return env
Example #31
 def _get_security_context_val(self, scontext: str) -> Union[str, int]:
     val = conf.get(self.kubernetes_section, scontext)
     if not val:
         return ""
     else:
         return int(val)
Example #32
import logging

from airflow.configuration import conf
from airflow.executors.local_executor import LocalExecutor
from airflow.executors.celery_executor import CeleryExecutor
from airflow.executors.sequential_executor import SequentialExecutor

_EXECUTOR = conf.get('core', 'EXECUTOR')

if _EXECUTOR == 'LocalExecutor':
    DEFAULT_EXECUTOR = LocalExecutor()
elif _EXECUTOR == 'CeleryExecutor':
    DEFAULT_EXECUTOR = CeleryExecutor()
elif _EXECUTOR == 'SequentialExecutor':
    DEFAULT_EXECUTOR = SequentialExecutor()
else:
    raise Exception("Executor {0} not supported.".format(_EXECUTOR))

logging.info("Using executor " + _EXECUTOR)
Example #33
"""Default configuration for the Airflow webserver"""
import os

from flask_appbuilder.security.manager import AUTH_DB

from airflow.configuration import conf

# from flask_appbuilder.security.manager import AUTH_LDAP
# from flask_appbuilder.security.manager import AUTH_OAUTH
# from flask_appbuilder.security.manager import AUTH_OID
# from flask_appbuilder.security.manager import AUTH_REMOTE_USER

basedir = os.path.abspath(os.path.dirname(__file__))

# The SQLAlchemy connection string.
SQLALCHEMY_DATABASE_URI = conf.get('core', 'SQL_ALCHEMY_CONN')

# Flask-WTF flag for CSRF
WTF_CSRF_ENABLED = True

# ----------------------------------------------------
# AUTHENTICATION CONFIG
# ----------------------------------------------------
# For details on how to set up each of the following authentication methods, see
# http://flask-appbuilder.readthedocs.io/en/latest/security.html#authentication-methods

# The authentication type
# AUTH_OID : Is for OpenID
# AUTH_DB : Is for database
# AUTH_LDAP : Is for LDAP
Example #34
    def __init__(
        self,
        dag_directory: Union[str, "pathlib.Path"],
        max_runs: int,
        processor_timeout: timedelta,
        signal_conn: MultiprocessingConnection,
        dag_ids: Optional[List[str]],
        pickle_dags: bool,
        async_mode: bool = True,
    ):
        super().__init__()
        self._file_paths: List[str] = []
        self._file_path_queue: List[str] = []
        self._dag_directory = dag_directory
        self._max_runs = max_runs
        self._signal_conn = signal_conn
        self._pickle_dags = pickle_dags
        self._dag_ids = dag_ids
        self._async_mode = async_mode
        self._parsing_start_time: Optional[int] = None

        # Set the signal conn into non-blocking mode, so that attempting to
        # send when the buffer is full raises an error rather than hanging
        # forever (this is to avoid deadlocks!)
        #
        # Don't do this in sync_mode, as we _need_ the DagParsingStat sent to
        # continue the scheduler
        if self._async_mode:
            os.set_blocking(self._signal_conn.fileno(), False)

        self._parallelism = conf.getint('scheduler', 'parsing_processes')
        if conf.get('core', 'sql_alchemy_conn').startswith('sqlite') and self._parallelism > 1:
            self.log.warning(
                "Because we cannot use more than 1 thread (parsing_processes = "
                "%d) when using sqlite. So we set parallelism to 1.",
                self._parallelism,
            )
            self._parallelism = 1

        # Parse and schedule each file no faster than this interval.
        self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval')
        # How often to print out DAG file processing stats to the log. Default to
        # 30 seconds.
        self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval')
        # How many seconds to wait for tasks to heartbeat before marking them as zombies.
        self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

        # Map from file path to the processor
        self._processors: Dict[str, DagFileProcessorProcess] = {}

        self._num_run = 0

        # Map from file path to stats about the file
        self._file_stats: Dict[str, DagFileStat] = {}

        self._last_zombie_query_time = None
        # Last time that the DAG dir was traversed to look for files
        self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0))
        # Last time stats were printed
        self.last_stat_print_time = 0
        # TODO: Remove magic number
        self._zombie_query_interval = 10
        # How long to wait before timing out a process to parse a DAG file
        self._processor_timeout = processor_timeout

        # How often to scan the DAGs directory for new files. Default to 5 minutes.
        self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval')

        # Mapping file name and callbacks requests
        self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list)

        self._log = logging.getLogger('airflow.processor_manager')

        self.waitables: Dict[Any, Union[MultiprocessingConnection, DagFileProcessorProcess]] = {
            self._signal_conn: self._signal_conn,
        }
Example #35
# under the License.
"""Default celery configuration."""
import logging
import ssl

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException


def _broker_supports_visibility_timeout(url):
    return url.startswith("redis://") or url.startswith("sqs://")


log = logging.getLogger(__name__)

broker_url = conf.get('celery', 'BROKER_URL')

broker_transport_options = conf.getsection(
    'celery_broker_transport_options'
)
if 'visibility_timeout' not in broker_transport_options:
    if _broker_supports_visibility_timeout(broker_url):
        broker_transport_options['visibility_timeout'] = 21600

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': 1,
    'task_acks_late': True,
    'task_default_queue': conf.get('celery', 'DEFAULT_QUEUE'),
    'task_default_exchange': conf.get('celery', 'DEFAULT_QUEUE'),
Example #36
# under the License.
"""Default celery configuration."""
import logging
import ssl

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException


def _broker_supports_visibility_timeout(url):
    return url.startswith("redis://") or url.startswith("sqs://")


log = logging.getLogger(__name__)

broker_url = conf.get('celery', 'BROKER_URL')

broker_transport_options = conf.getsection(
    'celery_broker_transport_options') or {}
if 'visibility_timeout' not in broker_transport_options:
    if _broker_supports_visibility_timeout(broker_url):
        broker_transport_options['visibility_timeout'] = 21600

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'),
    'task_acks_late': True,
Example #37
    def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_processor(
            self):
        """
        Check that the same set of failure callbacks with zombies is passed to the dag
        file processors until the next zombie detection logic is invoked.
        """
        test_dag_path = os.path.join(TEST_DAG_FOLDER,
                                     'test_example_bash_operator.py')
        with conf_vars({
            ('scheduler', 'parsing_processes'): '1',
            ('core', 'load_examples'): 'False'
        }):
            dagbag = DagBag(test_dag_path, read_dags_from_db=False)
            with create_session() as session:
                session.query(LJ).delete()
                dag = dagbag.get_dag('test_example_bash_operator')
                dag.sync_to_db()
                task = dag.get_task(task_id='run_this_last')

                ti = TI(task, DEFAULT_DATE, State.RUNNING)
                local_job = LJ(ti)
                local_job.state = State.SHUTDOWN
                session.add(local_job)
                session.commit()

                # TODO: If there was an actual Relationship between TI and Job
                # we wouldn't need this extra commit
                session.add(ti)
                ti.job_id = local_job.id
                session.commit()

                expected_failure_callback_requests = [
                    TaskCallbackRequest(
                        full_filepath=dag.full_filepath,
                        simple_task_instance=SimpleTaskInstance(ti),
                        msg="Message",
                    )
                ]

            test_dag_path = os.path.join(TEST_DAG_FOLDER,
                                         'test_example_bash_operator.py')

            child_pipe, parent_pipe = multiprocessing.Pipe()
            async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')

            fake_processors = []

            def fake_processor_factory(*args, **kwargs):
                nonlocal fake_processors
                processor = FakeDagFileProcessorRunner._fake_dag_processor_factory(
                    *args, **kwargs)
                fake_processors.append(processor)
                return processor

            manager = DagFileProcessorManager(
                dag_directory=test_dag_path,
                max_runs=1,
                processor_factory=fake_processor_factory,
                processor_timeout=timedelta.max,
                signal_conn=child_pipe,
                dag_ids=[],
                pickle_dags=False,
                async_mode=async_mode,
            )

            self.run_processor_manager_one_loop(manager, parent_pipe)

            if async_mode:
                # Once for initial parse, and then again for the add_callback_to_queue
                assert len(fake_processors) == 2
                assert fake_processors[0]._file_path == test_dag_path
                assert fake_processors[0]._callback_requests == []
            else:
                assert len(fake_processors) == 1

            assert fake_processors[-1]._file_path == test_dag_path
            callback_requests = fake_processors[-1]._callback_requests
            assert {
                zombie.simple_task_instance.key
                for zombie in expected_failure_callback_requests
            } == {
                result.simple_task_instance.key
                for result in callback_requests
            }

            child_pipe.close()
            parent_pipe.close()
Example #38
def create_app(config=None, testing=False, app_name="Airflow"):
    """Create a new instance of Airflow WWW app"""
    flask_app = Flask(__name__)
    flask_app.secret_key = conf.get('webserver', 'SECRET_KEY')

    session_lifetime_days = conf.getint('webserver',
                                        'SESSION_LIFETIME_DAYS',
                                        fallback=30)
    flask_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        days=session_lifetime_days)

    flask_app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    flask_app.config['APP_NAME'] = app_name
    flask_app.config['TESTING'] = testing
    flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    flask_app.config['SESSION_COOKIE_HTTPONLY'] = True
    flask_app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    flask_app.config['SESSION_COOKIE_SAMESITE'] = conf.get(
        'webserver', 'COOKIE_SAMESITE')

    if config:
        flask_app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    flask_app.json_encoder = AirflowJsonEncoder

    csrf.init_app(flask_app)

    init_wsgi_middleware(flask_app)

    db = SQLA()
    db.session = settings.Session
    db.init_app(flask_app)

    init_dagbag(flask_app)

    init_api_experimental_auth(flask_app)

    Cache(app=flask_app,
          config={
              'CACHE_TYPE': 'filesystem',
              'CACHE_DIR': '/tmp'
          })

    init_flash_views(flask_app)

    configure_logging()
    configure_manifest_files(flask_app)

    with flask_app.app_context():
        init_appbuilder(flask_app)

        init_appbuilder_views(flask_app)
        init_appbuilder_links(flask_app)
        init_plugins(flask_app)
        init_error_handlers(flask_app)
        init_api_connexion(flask_app)
        init_api_experimental(flask_app)

        sync_appbuilder_roles(flask_app)

        init_jinja_globals(flask_app)
        init_logout_timeout(flask_app)
        init_xframe_protection(flask_app)
        init_permanent_session(flask_app)

    return flask_app
Example #39
    def on_kill(self) -> None:
        """Kill Spark submit command"""
        self.log.debug("Kill Command is being called")

        if self._should_track_driver_status:
            if self._driver_id:
                self.log.info('Killing driver %s on cluster', self._driver_id)

                kill_cmd = self._build_spark_driver_kill_command()
                with subprocess.Popen(kill_cmd,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE) as driver_kill:
                    self.log.info(
                        "Spark driver %s killed with return code: %s",
                        self._driver_id, driver_kill.wait())

        if self._submit_sp and self._submit_sp.poll() is None:
            self.log.info('Sending kill signal to %s',
                          self._connection['spark_binary'])
            self._submit_sp.kill()

            if self._yarn_application_id:
                kill_cmd = f"yarn application -kill {self._yarn_application_id}".split(
                )
                env = {**os.environ, **(self._env or {})}
                if self._keytab is not None and self._principal is not None:
                    # we are ignoring renewal failures from renew_from_kt
                    # here as the failure could just be due to a non-renewable ticket,
                    # we still attempt to kill the yarn application
                    renew_from_kt(self._principal,
                                  self._keytab,
                                  exit_on_fail=False)
                    env = os.environ.copy()
                    env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache')

                with subprocess.Popen(kill_cmd,
                                      env=env,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.PIPE) as yarn_kill:
                    self.log.info("YARN app killed with return code: %s",
                                  yarn_kill.wait())

            if self._kubernetes_driver_pod:
                self.log.info('Killing pod %s on Kubernetes',
                              self._kubernetes_driver_pod)

                # Currently only instantiate Kubernetes client for killing a spark pod.
                try:
                    import kubernetes

                    client = kube_client.get_kube_client()
                    api_response = client.delete_namespaced_pod(
                        self._kubernetes_driver_pod,
                        self._connection['namespace'],
                        body=kubernetes.client.V1DeleteOptions(),
                        pretty=True,
                    )

                    self.log.info("Spark on K8s killed with response: %s",
                                  api_response)

                except kube_client.ApiException:
                    self.log.exception(
                        "Exception when attempting to kill Spark on K8s")
Example #40
    ("file",),
    help="Import variables from JSON file")
ARG_VAR_EXPORT = Arg(
    ("file",),
    help="Export all variables to JSON file")

# kerberos
ARG_PRINCIPAL = Arg(
    ("principal",),
    help="kerberos principal",
    nargs='?')
ARG_KEYTAB = Arg(
    ("-k", "--keytab"),
    help="keytab",
    nargs='?',
    default=conf.get('kerberos', 'keytab'))
# run
# TODO(aoen): "force" is a poor choice of name here since it implies it overrides
# all dependencies (not just past success), e.g. the ignore_depends_on_past
# dependency. This flag should be deprecated and renamed to 'ignore_ti_state' and
# the "ignore_all_dependencies" command should be called the"force" command
# instead.
ARG_INTERACTIVE = Arg(
    ('-N', '--interactive'),
    help='Do not capture standard output and error streams '
         '(useful for interactive debugging)',
    action='store_true')
ARG_FORCE = Arg(
    ("-f", "--force"),
    help="Ignore previous task instance state, rerun regardless if task already succeeded/failed",
    action="store_true")
Example #41
import sys
from time import sleep

from sqlalchemy import Column, Integer, String, DateTime, func, Index
from sqlalchemy.orm.session import make_transient

from airflow import executors, models, settings, utils
from airflow.configuration import conf
from airflow.utils import AirflowException, State

Base = models.Base
ID_LEN = models.ID_LEN

# Setting up a statsd client if needed
statsd = None
if conf.get('scheduler', 'statsd_on'):
    from statsd import StatsClient
    statsd = StatsClient(host=conf.get('scheduler', 'statsd_host'),
                         port=conf.getint('scheduler', 'statsd_port'),
                         prefix=conf.get('scheduler', 'statsd_prefix'))


class BaseJob(Base):
    """
    Abstract class to be derived for jobs. Jobs are processing items with state
    and duration that aren't task instances. For instance a BackfillJob is
    a collection of task instance runs, but should have its own state, start
    and end time.
    """

    __tablename__ = "job"
Example #42
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.airflow_home = configuration.get(self.core_section,
                                              'airflow_home')
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section,
                                                'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(self.kubernetes_section,
                                                      'worker_container_tag')
        self.worker_dags_folder = configuration.get(self.kubernetes_section,
                                                    'worker_dags_folder')
        self.kube_image = '{}:{}'.format(self.worker_container_repository,
                                         self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy")
        self.kube_node_selectors = configuration_dict.get(
            'kubernetes_node_selectors', {})
        self.delete_worker_pods = conf.getboolean(self.kubernetes_section,
                                                  'delete_worker_pods')

        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section,
                                           'image_pull_secrets')

        # NOTE: user can build the dags into the docker image directly,
        # this will set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section,
                                             'dags_in_image')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')

        # Optionally a user may supply a `git_user` and `git_password` for private
        # repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section,
                                          'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section,
                                          'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(self.kubernetes_section,
                                            'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(self.kubernetes_section,
                                            'logs_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section,
                                                 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside. Note
        # that if your
        # cluster has RBAC enabled, your scheduler may need service account permissions to
        # create, watch, get, and delete pods in this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor. Note
        # that if your
        # cluster has RBAC enabled, your workers may need service account permissions to
        # interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section,
                                           'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(self.kubernetes_section,
                                               'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a  valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section,
                                          'airflow_configmap')

        self._validate()
Example #43
def worker(args):
    """Starts Airflow Celery worker"""
    env = os.environ.copy()
    env['AIRFLOW_HOME'] = settings.AIRFLOW_HOME

    if not settings.validate_session():
        print("Worker exiting... database connection precheck failed! ")
        sys.exit(1)

    # Celery worker
    from airflow.executors.celery_executor import app as celery_app
    from celery.bin import worker  # pylint: disable=redefined-outer-name

    autoscale = args.autoscale
    if autoscale is None and conf.has_option("celery", "worker_autoscale"):
        autoscale = conf.get("celery", "worker_autoscale")
    worker = worker.worker(app=celery_app)  # pylint: disable=redefined-outer-name
    options = {
        'optimization': 'fair',
        'O': 'fair',
        'queues': args.queues,
        'concurrency': args.concurrency,
        'autoscale': autoscale,
        'hostname': args.celery_hostname,
        'loglevel': conf.get('core', 'LOGGING_LEVEL'),
    }

    if conf.has_option("celery", "pool"):
        options["pool"] = conf.get("celery", "pool")

    if args.daemon:
        pid, stdout, stderr, log_file = setup_locations(
            "worker", args.pid, args.stdout, args.stderr, args.log_file)
        handle = setup_logging(log_file)
        stdout = open(stdout, 'w+')
        stderr = open(stderr, 'w+')

        ctx = daemon.DaemonContext(
            pidfile=TimeoutPIDLockFile(pid, -1),
            files_preserve=[handle],
            stdout=stdout,
            stderr=stderr,
        )
        with ctx:
            sub_proc = subprocess.Popen(['airflow', 'serve_logs'],
                                        env=env,
                                        close_fds=True)
            worker.run(**options)
            sub_proc.kill()

        stdout.close()
        stderr.close()
    else:
        signal.signal(signal.SIGINT, sigint_handler)
        signal.signal(signal.SIGTERM, sigint_handler)

        sub_proc = subprocess.Popen(['airflow', 'serve_logs'],
                                    env=env,
                                    close_fds=True)

        worker.run(**options)
        sub_proc.kill()
Example #44
import os
from airflow import DAG  # This module must be imported for airflow to see DAGs
from airflow.configuration import conf

from ewah.dag_factories import dags_from_yml_file

file_name = "{0}{1}{2}".format(
    os.environ.get("AIRFLOW__CORE__DAGS_FOLDER",
                   conf.get("core", "dags_folder")),
    os.sep,
    "dags.yml",
)
if os.path.isfile(file_name):
    dags = dags_from_yml_file(file_name, True, True)
    for dag in dags:  # Must add the individual DAGs to the global namespace
        globals()[dag._dag_id] = dag
else:
    raise Exception("Not a file: {0}".format(file_name))
Example #45
    def _read(self, ti, try_number, metadata=None):  # pylint: disable=unused-argument
        """
        Template method that contains custom logic of reading
        logs given the try_number.

        :param ti: task instance record
        :param try_number: current try_number to read log from
        :param metadata: log metadata,
                         can be used for streaming log reading and auto-tailing.
        :return: log message as a string and metadata.
        """
        # Task instance here might be different from task instance when
        # initializing the handler. Thus explicitly getting log location
        # is needed to get correct log path.
        log_relative_path = self._render_filename(ti, try_number)
        location = os.path.join(self.local_base, log_relative_path)

        log = ""

        if os.path.exists(location):
            try:
                with open(location) as file:
                    log += f"*** Reading local file: {location}\n"
                    log += "".join(file.readlines())
            except Exception as e:  # pylint: disable=broad-except
                log = f"*** Failed to load local log file: {location}\n"
                log += f"*** {str(e)}\n"
        elif conf.get('core', 'executor') == 'KubernetesExecutor':  # pylint: disable=too-many-nested-blocks
            try:
                from airflow.kubernetes.kube_client import get_kube_client

                kube_client = get_kube_client()

                if len(ti.hostname) >= 63:
                    # Kubernetes takes the pod name and truncates it for the hostname. This truncated hostname
                    # is returned for the fqdn to comply with the 63 character limit imposed by DNS standards
                    # on any label of a FQDN.
                    pod_list = kube_client.list_namespaced_pod(
                        conf.get('kubernetes', 'namespace'))
                    matches = [
                        pod.metadata.name for pod in pod_list.items
                        if pod.metadata.name.startswith(ti.hostname)
                    ]
                    if len(matches) == 1:
                        if len(matches[0]) > len(ti.hostname):
                            ti.hostname = matches[0]

                log += '*** Trying to get logs (last 100 lines) from worker pod {} ***\n\n'.format(
                    ti.hostname)

                res = kube_client.read_namespaced_pod_log(
                    name=ti.hostname,
                    namespace=conf.get('kubernetes', 'namespace'),
                    container='base',
                    follow=False,
                    tail_lines=100,
                    _preload_content=False,
                )

                for line in res:
                    log += line.decode()

            except Exception as f:  # pylint: disable=broad-except
                log += f'*** Unable to fetch logs from worker pod {ti.hostname} ***\n{str(f)}\n\n'
        else:
            url = os.path.join(
                "http://{ti.hostname}:{worker_log_server_port}/log",
                log_relative_path).format(ti=ti,
                                          worker_log_server_port=conf.get(
                                              'celery',
                                              'WORKER_LOG_SERVER_PORT'))
            log += f"*** Log file does not exist: {location}\n"
            log += f"*** Fetching from: {url}\n"
            try:
                timeout = None  # No timeout
                try:
                    timeout = conf.getint('webserver', 'log_fetch_timeout_sec')
                except (AirflowConfigException, ValueError):
                    pass

                response = requests.get(url, timeout=timeout)
                response.encoding = "utf-8"

                # Check if the resource was properly fetched
                response.raise_for_status()

                log += '\n' + response.text
            except Exception as e:  # pylint: disable=broad-except
                log += f"*** Failed to fetch log file from worker. {str(e)}\n"

        return log, {'end_of_log': True}
Example #46
    def __init__(
        self,
        task_id: str,
        owner: str = conf.get('operators', 'DEFAULT_OWNER'),
        email: Optional[Union[str, Iterable[str]]] = None,
        email_on_retry: bool = True,
        email_on_failure: bool = True,
        retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0),
        retry_delay: timedelta = timedelta(seconds=300),
        retry_exponential_backoff: bool = False,
        max_retry_delay: Optional[datetime] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        depends_on_past: bool = False,
        wait_for_downstream: bool = False,
        dag=None,
        params: Optional[Dict] = None,
        default_args: Optional[Dict] = None,  # pylint: disable=unused-argument
        priority_weight: int = 1,
        weight_rule: str = WeightRule.DOWNSTREAM,
        queue: str = conf.get('celery', 'default_queue'),
        pool: str = Pool.DEFAULT_POOL_NAME,
        sla: Optional[timedelta] = None,
        execution_timeout: Optional[timedelta] = None,
        on_failure_callback: Optional[Callable] = None,
        on_success_callback: Optional[Callable] = None,
        on_retry_callback: Optional[Callable] = None,
        trigger_rule: str = TriggerRule.ALL_SUCCESS,
        resources: Optional[Dict] = None,
        run_as_user: Optional[str] = None,
        task_concurrency: Optional[int] = None,
        executor_config: Optional[Dict] = None,
        do_xcom_push: bool = True,
        inlets: Optional[Dict] = None,
        outlets: Optional[Dict] = None,
        *args,
        **kwargs
    ):
        from airflow.models.dag import DagContext
        super().__init__()
        if args or kwargs:
            if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                raise AirflowException(
                    "Invalid arguments were passed to {c} (task_id: {t}). Invalid "
                    "arguments were:\n*args: {a}\n**kwargs: {k}".format(
                        c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                )
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'future. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(
                    c=self.__class__.__name__, a=args, k=kwargs, t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3
            )
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_triggers=TriggerRule.all_triggers(),
                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            # noinspection PyTypeChecker
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'."
                .format(all_weight_rules=WeightRule.all_weight_rules,
                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
        self.weight_rule = weight_rule
        self.resources: Optional[Resources] = Resources(**resources) if resources else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids: Set[str] = set()
        self._downstream_task_ids: Set[str] = set()
        self._dag = None

        self.dag = dag or DagContext.get_current_dag()

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets: List[DataSet] = []
        self.outlets: List[DataSet] = []
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets: Dict[str, Iterable] = {
            "datasets": [],
        }

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)
Example #47
    def prepare_file_path_queue(self):
        """Generate more file paths to process. Result are saved in _file_path_queue."""
        self._parsing_start_time = time.perf_counter()
        # If the file path is already being processed, or if a file was
        # processed recently, wait until the next batch
        file_paths_in_progress = self._processors.keys()
        now = timezone.utcnow()

        # Sort the file paths by the parsing order mode
        list_mode = conf.get("scheduler", "file_parsing_sort_mode")

        files_with_mtime = {}
        file_paths = []
        is_mtime_mode = list_mode == "modified_time"

        file_paths_recently_processed = []
        for file_path in self._file_paths:

            if is_mtime_mode:
                try:
                    files_with_mtime[file_path] = os.path.getmtime(file_path)
                except FileNotFoundError:
                    self.log.warning("Skipping processing of missing file: %s", file_path)
                    continue
                file_modified_time = timezone.make_aware(datetime.fromtimestamp(files_with_mtime[file_path]))
            else:
                file_paths.append(file_path)
                file_modified_time = None

            # Find file paths that were recently processed to exclude them
            # from being added to file_path_queue
            # unless they were modified recently and parsing mode is "modified_time"
            # in which case we don't honor "self._file_process_interval" (min_file_process_interval)
            last_finish_time = self.get_last_finish_time(file_path)
            if (
                last_finish_time is not None
                and (now - last_finish_time).total_seconds() < self._file_process_interval
                and not (is_mtime_mode and file_modified_time and (file_modified_time > last_finish_time))
            ):
                file_paths_recently_processed.append(file_path)

        # Sort file paths via last modified time
        if is_mtime_mode:
            file_paths = sorted(files_with_mtime, key=files_with_mtime.get, reverse=True)
        elif list_mode == "alphabetical":
            file_paths = sorted(file_paths)
        elif list_mode == "random_seeded_by_host":
            # Shuffle the list seeded by hostname so multiple schedulers can work on different
            # set of files. Since we set the seed, the sort order will remain same per host
            random.Random(get_hostname()).shuffle(file_paths)

        files_paths_at_run_limit = [
            file_path for file_path, stat in self._file_stats.items() if stat.run_count == self._max_runs
        ]

        file_paths_to_exclude = set(file_paths_in_progress).union(
            file_paths_recently_processed, files_paths_at_run_limit
        )

        # Do not convert the following list to set as set does not preserve the order
        # and we need to maintain the order of file_paths for `[scheduler] file_parsing_sort_mode`
        files_paths_to_queue = [
            file_path for file_path in file_paths if file_path not in file_paths_to_exclude
        ]

        for file_path, processor in self._processors.items():
            self.log.debug(
                "File path %s is still being processed (started: %s)",
                processor.file_path,
                processor.start_time.isoformat(),
            )

        self.log.debug("Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue))

        for file_path in files_paths_to_queue:
            if file_path not in self._file_stats:
                self._file_stats[file_path] = DagFileStat(
                    num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0
                )

        self._file_path_queue.extend(files_paths_to_queue)
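
A minimal standalone sketch of the three file_parsing_sort_mode values handled above ("modified_time", "alphabetical", "random_seeded_by_host"); the helper name order_dag_files and the hostname lookup via socket.gethostname() are illustrative assumptions rather than part of the original module.

import os
import random
import socket

def order_dag_files(file_paths, list_mode):
    # Hypothetical helper mirroring the ordering logic of the snippet above.
    if list_mode == "modified_time":
        # Most recently modified files first; missing files are skipped.
        mtimes = {}
        for path in file_paths:
            try:
                mtimes[path] = os.path.getmtime(path)
            except FileNotFoundError:
                continue
        return sorted(mtimes, key=mtimes.get, reverse=True)
    if list_mode == "alphabetical":
        return sorted(file_paths)
    if list_mode == "random_seeded_by_host":
        # Seeding with the hostname keeps the shuffle stable per host while
        # letting different schedulers work on different orderings.
        shuffled = list(file_paths)
        random.Random(socket.gethostname()).shuffle(shuffled)
        return shuffled
    raise ValueError("Unknown file_parsing_sort_mode: %s" % list_mode)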
Example #48
    def is_scheduler_running(self, host):
        self.logger.info("Starting to Check if Scheduler on host '" +
                         str(host) + "' is running...")

        process_check_command = "ps -eaf"
        grep_command = "grep 'airflow scheduler' | grep -v grep || true"
        grep_command_no_quotes = grep_command.replace("'", "")
        full_status_check_command = process_check_command + " | " + grep_command  # ps -eaf | grep 'airflow scheduler' | grep -v grep || true
        is_running = False
        is_successful, output = self.command_runner.run_command(
            host, full_status_check_command)
        self.LATEST_FAILED_STATUS_MESSAGE = output
        if is_successful:
            active_list = []
            for line in output:
                if (line.strip() != "" and process_check_command not in line
                        and grep_command not in line
                        and grep_command_no_quotes not in line
                        and full_status_check_command not in line):
                    active_list.append(line)

            active_list_length = len(list(filter(None, active_list)))

            # todo: If there's more than one scheduler running, this should kill off the other schedulers. MIGHT ALREADY BE HANDLED. DOUBLE CHECK.

            is_running = active_list_length > 0
            if is_running:
                af_health_url = str(conf.get("webserver",
                                             "base_url")) + "/health"
                self.logger.info("Airflow Health URL: " + str(af_health_url))
                try:
                    response = requests.request("GET",
                                                af_health_url,
                                                verify=False)
                    json_data = response.json()

                    scheduler_status = json_data["scheduler"]["status"]
                    if scheduler_status.lower() == "healthy":
                        self.logger.info(
                            "According to the webserver, scheduler_status: " +
                            str(scheduler_status))
                        is_running = True

                    else:
                        is_running = False
                        self.logger.info(
                            "Finished Checking if Scheduler on host '" +
                            str(host) +
                            "' is running. Seems the airflow scheduler is hung. According to the webserver, scheduler_status: "
                            + str(scheduler_status))

                        # Killing hung processes
                        self.shutdown_scheduler(host)
                except Exception as e:
                    self.logger.warning(e)
                    self.logger.warning(
                        "Failed to do a GET call on the Airflow webserver")
        else:
            self.logger.critical("is_scheduler_running check failed")

        self.logger.info("Finished Checking if Scheduler on host '" +
                         str(host) + "' is running. is_running: " +
                         str(is_running))

        return is_running
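
For reference, a hedged sketch of the webserver health probe this check relies on. The /health payload shape (a top-level "scheduler" object with a "status" field) is inferred from the parsing above; the function name and timeout are illustrative.

import requests

from airflow.configuration import conf

def scheduler_reported_healthy(timeout=10):
    # Assumed response shape: {"scheduler": {"status": "healthy", ...}, ...}
    health_url = conf.get("webserver", "base_url") + "/health"
    response = requests.get(health_url, verify=False, timeout=timeout)
    payload = response.json()
    return payload.get("scheduler", {}).get("status", "").lower() == "healthy"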
Example #49
    def test_1_9_config(self):
        from airflow.logging_config import configure_logging
        with conf_vars({('logging', 'task_log_reader'): 'file.task'}):
            with self.assertWarnsRegex(DeprecationWarning, r'file.task'):
                configure_logging()
            self.assertEqual(conf.get('logging', 'task_log_reader'), 'task')
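
conf_vars (imported elsewhere in the test module) acts as a context manager that temporarily overrides (section, key) pairs and restores them afterwards; a hedged usage sketch, with the import path assumed from Airflow's test utilities:

from tests.test_utils.config import conf_vars  # assumed test-utility location

from airflow.configuration import conf

with conf_vars({("logging", "task_log_reader"): "file.task"}):
    # The override is only visible inside the block ...
    assert conf.get("logging", "task_log_reader") == "file.task"
# ... and the previous value is restored on exit.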
Example #50
    def test_case_sensitivity(self):
        # section and key are case insensitive for get method
        # note: this is not the case for as_dict method
        self.assertEqual(conf.get("core", "percent"), "with%inside")
        self.assertEqual(conf.get("core", "PERCENT"), "with%inside")
        self.assertEqual(conf.get("CORE", "PERCENT"), "with%inside")
Example #51
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
    import jinja2  # Slow import.

    from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
    from airflow.models.dag import DAG
    from airflow.models.operator import Operator
    from airflow.models.taskinstance import TaskInstance

DEFAULT_OWNER: str = conf.get("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get("operators", "default_queue")
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
    conf.get("core",
             "default_task_weight_rule",
             fallback=WeightRule.DOWNSTREAM))
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS


class AbstractOperator(LoggingMixin, DAGNode):
    """Common implementation for operators, including unmapped and mapped.
Example #52
    def __init__(self,
                 dag_directory: str,
                 max_runs: int,
                 processor_factory: Callable[
                     [str, List[FailureCallbackRequest]],
                     AbstractDagFileProcessorProcess
                 ],
                 processor_timeout: timedelta,
                 signal_conn: MultiprocessingConnection,
                 dag_ids: Optional[List[str]],
                 pickle_dags: bool,
                 async_mode: bool = True):
        super().__init__()
        self._file_paths: List[str] = []
        self._file_path_queue: List[str] = []
        self._dag_directory = dag_directory
        self._max_runs = max_runs
        self._processor_factory = processor_factory
        self._signal_conn = signal_conn
        self._pickle_dags = pickle_dags
        self._dag_ids = dag_ids
        self._async_mode = async_mode
        self._parsing_start_time: Optional[datetime] = None

        self._parallelism = conf.getint('scheduler', 'max_threads')
        if 'sqlite' in conf.get('core', 'sql_alchemy_conn') and self._parallelism > 1:
            self.log.warning(
                "Cannot use more than 1 thread (max_threads = %d) with a "
                "sqlite backend; setting parallelism to 1.", self._parallelism
            )
            self._parallelism = 1

        # Parse and schedule each file no faster than this interval.
        self._file_process_interval = conf.getint('scheduler',
                                                  'min_file_process_interval')
        # How often to print out DAG file processing stats to the log. Default to
        # 30 seconds.
        self.print_stats_interval = conf.getint('scheduler',
                                                'print_stats_interval')
        # How many seconds do we wait for tasks to heartbeat before mark them as zombies.
        self._zombie_threshold_secs = (
            conf.getint('scheduler', 'scheduler_zombie_task_threshold'))

        # Should store dag file source in a database?
        self.store_dag_code = STORE_DAG_CODE
        # Map from file path to the processor
        self._processors: Dict[str, AbstractDagFileProcessorProcess] = {}

        self._num_run = 0

        # Map from file path to stats about the file
        self._file_stats: Dict[str, DagFileStat] = {}

        self._last_zombie_query_time = None
        # Last time that the DAG dir was traversed to look for files
        self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0))
        # Last time stats were printed
        self.last_stat_print_time = timezone.datetime(2000, 1, 1)
        # TODO: Remove magic number
        self._zombie_query_interval = 10
        # How long to wait before timing out a process to parse a DAG file
        self._processor_timeout = processor_timeout

        # How often to scan the DAGs directory for new files. Default to 5 minutes.
        self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval')

        # Mapping file name and callbacks requests
        self._callback_to_execute: Dict[str, List[FailureCallbackRequest]] = defaultdict(list)

        self._log = logging.getLogger('airflow.processor_manager')

        self.waitables = {self._signal_conn: self._signal_conn}
Example #53
"""Default celery configuration."""
import logging
import ssl

from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException


def _broker_supports_visibility_timeout(url):
    return url.startswith("redis://") or url.startswith("sqs://")


log = logging.getLogger(__name__)

broker_url = conf.get('celery', 'BROKER_URL')

broker_transport_options = conf.getsection(
    'celery_broker_transport_options') or {}
if 'visibility_timeout' not in broker_transport_options:
    if _broker_supports_visibility_timeout(broker_url):
        broker_transport_options['visibility_timeout'] = 21600

DEFAULT_CELERY_CONFIG = {
    'accept_content': ['json'],
    'event_serializer': 'json',
    'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier', fallback=1),
    'task_acks_late': True,
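
conf.getsection used above returns the whole [celery_broker_transport_options] section as a dict (or None when the section is absent), which is why the snippet falls back to an empty dict; a hedged restatement of that behavior, with the 21600-second default applying only to redis:// and sqs:// brokers as shown:

from airflow.configuration import conf

# getsection returns every option in the section as a dict, or None when
# the section is missing entirely.
transport_options = conf.getsection('celery_broker_transport_options') or {}
transport_options.setdefault('visibility_timeout', 21600)  # redis/sqs brokers only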
Example #54
def make_emitter_hook() -> "DatahubGenericHook":
    # This is necessary to avoid issues with circular imports.
    from datahub.integrations.airflow.hooks import DatahubGenericHook

    _datahub_conn_id = conf.get("lineage", "datahub_conn_id")
    return DatahubGenericHook(_datahub_conn_id)
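
A hedged usage sketch for the hook factory above: the only wiring it needs is the [lineage] datahub_conn_id option, set here through an environment variable; the connection id datahub_rest_default is a placeholder.

import os

# Placeholder connection id; equivalent to setting
# [lineage] datahub_conn_id in airflow.cfg.
os.environ["AIRFLOW__LINEAGE__DATAHUB_CONN_ID"] = "datahub_rest_default"

hook = make_emitter_hook()  # DatahubGenericHook bound to that connection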
Example #55
    def start(self):
        self.task_queue = Queue()
        self.result_queue = Queue()
        framework = mesos_pb2.FrameworkInfo()
        framework.user = ''

        if not conf.get('mesos', 'MASTER'):
            logging.error("Expecting mesos master URL for mesos executor")
            raise AirflowException(
                "mesos.master not provided for mesos executor")

        master = conf.get('mesos', 'MASTER')

        if not conf.get('mesos', 'FRAMEWORK_NAME'):
            framework.name = 'Airflow'
        else:
            framework.name = conf.get('mesos', 'FRAMEWORK_NAME')

        if not conf.get('mesos', 'TASK_CPU'):
            task_cpu = 1
        else:
            task_cpu = conf.getint('mesos', 'TASK_CPU')

        if not conf.get('mesos', 'TASK_MEMORY'):
            task_memory = 256
        else:
            task_memory = conf.getint('mesos', 'TASK_MEMORY')

        framework.checkpoint = conf.getboolean('mesos', 'CHECKPOINT')

        logging.info(
            'MesosFramework master : %s, name : %s, cpu : %s, mem : %s, checkpoint : %s',
            master, framework.name, str(task_cpu), str(task_memory),
            str(framework.checkpoint))

        implicit_acknowledgements = 1

        if (conf.getboolean('mesos', 'AUTHENTICATE')):
            if not conf.get('mesos', 'DEFAULT_PRINCIPAL'):
                logging.error(
                    "Expecting authentication principal in the environment")
                raise AirflowException(
                    "mesos.default_principal not provided in authenticated mode"
                )
            if not conf.get('mesos', 'DEFAULT_SECRET'):
                logging.error(
                    "Expecting authentication secret in the environment")
                raise AirflowException(
                    "mesos.default_secret not provided in authenticated mode")

            credential = mesos_pb2.Credential()
            credential.principal = conf.get('mesos', 'DEFAULT_PRINCIPAL')
            credential.secret = conf.get('mesos', 'DEFAULT_SECRET')

            framework.principal = credential.principal

            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue, self.result_queue,
                                      task_cpu, task_memory), framework,
                master, implicit_acknowledgements, credential)
        else:
            framework.principal = 'Airflow'
            driver = mesos.native.MesosSchedulerDriver(
                AirflowMesosScheduler(self.task_queue, self.result_queue,
                                      task_cpu, task_memory), framework,
                master, implicit_acknowledgements)

        self.mesos_driver = driver
        self.mesos_driver.start()
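
The repeated "if not conf.get(...)" branches above raise when an option is missing from the [mesos] section altogether; a hedged alternative sketch that expresses the same defaults with fallback= (the helper name is illustrative and the default values are taken from the snippet):

from airflow.configuration import conf
from airflow.exceptions import AirflowException

def read_mesos_settings():
    # Same defaults as the snippet, expressed with fallbacks.
    master = conf.get('mesos', 'MASTER', fallback=None)
    if not master:
        raise AirflowException("mesos.master not provided for mesos executor")
    return {
        'master': master,
        'framework_name': conf.get('mesos', 'FRAMEWORK_NAME', fallback='Airflow'),
        'task_cpu': conf.getint('mesos', 'TASK_CPU', fallback=1),
        'task_memory': conf.getint('mesos', 'TASK_MEMORY', fallback=256),
        'checkpoint': conf.getboolean('mesos', 'CHECKPOINT', fallback=False),
    }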
Example #56
CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'

CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'

OPERATION_TIMEOUT = conf.getfloat('celery', 'operation_timeout', fallback=1.0)
'''
To start the celery worker, run the command:
airflow celery worker
'''

if conf.has_option('celery', 'celery_config_options'):
    celery_configuration = conf.getimport('celery', 'celery_config_options')
else:
    celery_configuration = DEFAULT_CELERY_CONFIG

app = Celery(conf.get('celery', 'CELERY_APP_NAME'),
             config_source=celery_configuration)


@app.task
def execute_command(command_to_exec: CommandType) -> None:
    """Executes command."""
    BaseExecutor.validate_command(command_to_exec)
    log.info("Executing command in Celery: %s", command_to_exec)

    if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
        _execute_in_subprocess(command_to_exec)
    else:
        _execute_in_fork(command_to_exec)
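
conf.getimport above resolves a dotted-path string from the configuration into the object it names, which is how a custom Celery configuration is swapped in; a hedged, self-contained restatement with the import the excerpt omits (the dotted path in the comment is a placeholder):

from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.configuration import conf

# [celery] celery_config_options, when set, holds a dotted path such as
# "my_company.celery_config.CELERY_CONFIG" (placeholder) that getimport
# resolves to the actual object; otherwise the stock defaults are used.
if conf.has_option('celery', 'celery_config_options'):
    celery_configuration = conf.getimport('celery', 'celery_config_options')
else:
    celery_configuration = DEFAULT_CELERY_CONFIG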

Example #57
    def __init__(
            self,
            task_id,  # type: str
            owner=conf.get('operators', 'DEFAULT_OWNER'),  # type: str
            email=None,  # type: Optional[Union[str, Iterable[str]]]
            email_on_retry=True,  # type: bool
            email_on_failure=True,  # type: bool
            retries=conf.getint('core', 'default_task_retries',
                                fallback=0),  # type: int
            retry_delay=timedelta(seconds=300),  # type: timedelta
            retry_exponential_backoff=False,  # type: bool
            max_retry_delay=None,  # type: Optional[datetime]
            start_date=None,  # type: Optional[datetime]
            end_date=None,  # type: Optional[datetime]
            schedule_interval=None,  # not hooked as of now
            depends_on_past=False,  # type: bool
            wait_for_downstream=False,  # type: bool
            dag=None,  # type: Optional[DAG]
            params=None,  # type: Optional[Dict]
            default_args=None,  # type: Optional[Dict]
            priority_weight=1,  # type: int
            weight_rule=WeightRule.DOWNSTREAM,  # type: str
            queue=conf.get('celery', 'default_queue'),  # type: str
            pool=Pool.DEFAULT_POOL_NAME,  # type: str
            pool_slots=1,  # type: int
            sla=None,  # type: Optional[timedelta]
            execution_timeout=None,  # type: Optional[timedelta]
            on_failure_callback=None,  # type: Optional[Callable]
            on_success_callback=None,  # type: Optional[Callable]
            on_retry_callback=None,  # type: Optional[Callable]
            trigger_rule=TriggerRule.ALL_SUCCESS,  # type: str
            resources=None,  # type: Optional[Dict]
            run_as_user=None,  # type: Optional[str]
            task_concurrency=None,  # type: Optional[int]
            executor_config=None,  # type: Optional[Dict]
            do_xcom_push=True,  # type: bool
            inlets=None,  # type: Optional[Dict]
            outlets=None,  # type: Optional[Dict]
            *args,
            **kwargs):
        super().__init__()

        if args or kwargs:
            # TODO remove *args and **kwargs in Airflow 2.0
            warnings.warn(
                'Invalid arguments were passed to {c} (task_id: {t}). '
                'Support for passing such arguments will be dropped in '
                'Airflow 2.0. Invalid arguments were:'
                '\n*args: {a}\n**kwargs: {k}'.format(c=self.__class__.__name__,
                                                     a=args,
                                                     k=kwargs,
                                                     t=task_id),
                category=PendingDeprecationWarning,
                stacklevel=3)
        validate_key(task_id)
        self.task_id = task_id
        self.owner = owner
        self.email = email
        self.email_on_retry = email_on_retry
        self.email_on_failure = email_on_failure

        self.start_date = start_date
        if start_date and not isinstance(start_date, datetime):
            self.log.warning("start_date for %s isn't datetime.datetime", self)
        elif start_date:
            self.start_date = timezone.convert_to_utc(start_date)

        self.end_date = end_date
        if end_date:
            self.end_date = timezone.convert_to_utc(end_date)

        if not TriggerRule.is_valid(trigger_rule):
            raise AirflowException(
                "The trigger_rule must be one of {all_triggers},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_triggers=TriggerRule.all_triggers(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=trigger_rule))

        self.trigger_rule = trigger_rule
        self.depends_on_past = depends_on_past
        self.wait_for_downstream = wait_for_downstream
        if wait_for_downstream:
            self.depends_on_past = True

        if schedule_interval:
            self.log.warning(
                "schedule_interval is used for %s, though it has "
                "been deprecated as a task parameter, you need to "
                "specify it as a DAG parameter instead", self)
        self._schedule_interval = schedule_interval
        self.retries = retries
        self.queue = queue
        self.pool = pool
        self.pool_slots = pool_slots
        if self.pool_slots < 1:
            raise AirflowException(
                "pool slots for %s in dag %s cannot be less than 1" %
                (self.task_id, self.dag_id))
        self.sla = sla
        self.execution_timeout = execution_timeout
        self.on_failure_callback = on_failure_callback
        self.on_success_callback = on_success_callback
        self.on_retry_callback = on_retry_callback

        if isinstance(retry_delay, timedelta):
            self.retry_delay = retry_delay
        else:
            self.log.debug("Retry_delay isn't timedelta object, assuming secs")
            self.retry_delay = timedelta(seconds=retry_delay)
        self.retry_exponential_backoff = retry_exponential_backoff
        self.max_retry_delay = max_retry_delay
        self.params = params or {}  # Available in templates!
        self.priority_weight = priority_weight
        if not WeightRule.is_valid(weight_rule):
            raise AirflowException(
                "The weight_rule must be one of {all_weight_rules},"
                "'{d}.{t}'; received '{tr}'.".format(
                    all_weight_rules=WeightRule.all_weight_rules(),
                    d=dag.dag_id if dag else "",
                    t=task_id,
                    tr=weight_rule))
        self.weight_rule = weight_rule

        self.resources = Resources(
            **resources) if resources is not None else None
        self.run_as_user = run_as_user
        self.task_concurrency = task_concurrency
        self.executor_config = executor_config or {}
        self.do_xcom_push = do_xcom_push

        # Private attributes
        self._upstream_task_ids = set()  # type: Set[str]
        self._downstream_task_ids = set()  # type: Set[str]

        if not dag and settings.CONTEXT_MANAGER_DAG:
            dag = settings.CONTEXT_MANAGER_DAG
        if dag:
            self.dag = dag

        # subdag parameter is only set for SubDagOperator.
        # Setting it to None by default as other Operators do not have that field
        self.subdag = None  # type: Optional[DAG]

        self._log = logging.getLogger("airflow.task.operators")

        # lineage
        self.inlets = []  # type: Iterable[DataSet]
        self.outlets = []  # type: Iterable[DataSet]
        self.lineage_data = None

        self._inlets = {
            "auto": False,
            "task_ids": [],
            "datasets": [],
        }

        self._outlets = {
            "datasets": [],
        }  # type: Dict

        if inlets:
            self._inlets.update(inlets)

        if outlets:
            self._outlets.update(outlets)
Example #58
    def __init__(self):
        configuration_dict = configuration.as_dict(display_sensitive=True)
        self.core_configuration = configuration_dict['core']
        self.kube_secrets = configuration_dict.get('kubernetes_secrets', {})
        self.kube_env_vars = configuration_dict.get('kubernetes_environment_variables', {})
        self.env_from_configmap_ref = configuration.get(self.kubernetes_section,
                                                        'env_from_configmap_ref')
        self.env_from_secret_ref = configuration.get(self.kubernetes_section,
                                                     'env_from_secret_ref')
        self.airflow_home = settings.AIRFLOW_HOME
        self.dags_folder = configuration.get(self.core_section, 'dags_folder')
        self.parallelism = configuration.getint(self.core_section, 'PARALLELISM')
        self.worker_container_repository = configuration.get(
            self.kubernetes_section, 'worker_container_repository')
        self.worker_container_tag = configuration.get(
            self.kubernetes_section, 'worker_container_tag')
        self.kube_image = '{}:{}'.format(
            self.worker_container_repository, self.worker_container_tag)
        self.kube_image_pull_policy = configuration.get(
            self.kubernetes_section, "worker_container_image_pull_policy"
        )
        self.kube_node_selectors = configuration_dict.get('kubernetes_node_selectors', {})
        self.kube_annotations = configuration_dict.get('kubernetes_annotations', {})
        self.kube_labels = configuration_dict.get('kubernetes_labels', {})
        self.delete_worker_pods = conf.getboolean(
            self.kubernetes_section, 'delete_worker_pods')
        self.worker_pods_creation_batch_size = conf.getint(
            self.kubernetes_section, 'worker_pods_creation_batch_size')
        self.worker_service_account_name = conf.get(
            self.kubernetes_section, 'worker_service_account_name')
        self.image_pull_secrets = conf.get(self.kubernetes_section, 'image_pull_secrets')

        # NOTE: user can build the dags into the docker image directly,
        # this will set to True if so
        self.dags_in_image = conf.getboolean(self.kubernetes_section, 'dags_in_image')

        # Run as user for pod security context
        self.worker_run_as_user = self._get_security_context_val('run_as_user')
        self.worker_fs_group = self._get_security_context_val('fs_group')

        # NOTE: `git_repo` and `git_branch` must be specified together as a pair
        # The http URL of the git repository to clone from
        self.git_repo = conf.get(self.kubernetes_section, 'git_repo')
        # The branch of the repository to be checked out
        self.git_branch = conf.get(self.kubernetes_section, 'git_branch')
        # Optionally, the directory in the git repository containing the dags
        self.git_subpath = conf.get(self.kubernetes_section, 'git_subpath')
        # Optionally, the root directory for git operations
        self.git_sync_root = conf.get(self.kubernetes_section, 'git_sync_root')
        # Optionally, the name at which to publish the checked-out files under --root
        self.git_sync_dest = conf.get(self.kubernetes_section, 'git_sync_dest')
        # Optionally, if git_dags_folder_mount_point is set the worker will use
        # {git_dags_folder_mount_point}/{git_sync_dest}/{git_subpath} as dags_folder
        self.git_dags_folder_mount_point = conf.get(self.kubernetes_section,
                                                    'git_dags_folder_mount_point')

        # Optionally a user may supply a (`git_user` AND `git_password`) OR
        # (`git_ssh_key_secret_name` AND `git_ssh_key_secret_key`) for private repositories
        self.git_user = conf.get(self.kubernetes_section, 'git_user')
        self.git_password = conf.get(self.kubernetes_section, 'git_password')
        self.git_ssh_key_secret_name = conf.get(self.kubernetes_section, 'git_ssh_key_secret_name')
        self.git_ssh_known_hosts_configmap_name = conf.get(self.kubernetes_section,
                                                           'git_ssh_known_hosts_configmap_name')

        # NOTE: The user may optionally use a volume claim to mount a PV containing
        # DAGs directly
        self.dags_volume_claim = conf.get(self.kubernetes_section, 'dags_volume_claim')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.logs_volume_claim = conf.get(self.kubernetes_section, 'logs_volume_claim')

        # This prop may optionally be set for PV Claims and is used to locate DAGs
        # on a SubPath
        self.dags_volume_subpath = conf.get(
            self.kubernetes_section, 'dags_volume_subpath')

        # This prop may optionally be set for PV Claims and is used to locate logs
        # on a SubPath
        self.logs_volume_subpath = conf.get(
            self.kubernetes_section, 'logs_volume_subpath')

        # Optionally, hostPath volume containing DAGs
        self.dags_volume_host = conf.get(self.kubernetes_section, 'dags_volume_host')

        # Optionally, write logs to a hostPath Volume
        self.logs_volume_host = conf.get(self.kubernetes_section, 'logs_volume_host')

        # This prop may optionally be set for PV Claims and is used to write logs
        self.base_log_folder = configuration.get(self.core_section, 'base_log_folder')

        # The Kubernetes Namespace in which the Scheduler and Webserver reside. Note
        # that if your
        # cluster has RBAC enabled, your scheduler may need service account permissions to
        # create, watch, get, and delete pods in this namespace.
        self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
        # The Kubernetes Namespace in which pods will be created by the executor. Note
        # that if your
        # cluster has RBAC enabled, your workers may need service account permissions to
        # interact with cluster components.
        self.executor_namespace = conf.get(self.kubernetes_section, 'namespace')
        # Task secrets managed by KubernetesExecutor.
        self.gcp_service_account_keys = conf.get(self.kubernetes_section,
                                                 'gcp_service_account_keys')

        # If the user is using the git-sync container to clone their repository via git,
        # allow them to specify repository, tag, and pod name for the init container.
        self.git_sync_container_repository = conf.get(
            self.kubernetes_section, 'git_sync_container_repository')

        self.git_sync_container_tag = conf.get(
            self.kubernetes_section, 'git_sync_container_tag')
        self.git_sync_container = '{}:{}'.format(
            self.git_sync_container_repository, self.git_sync_container_tag)

        self.git_sync_init_container_name = conf.get(
            self.kubernetes_section, 'git_sync_init_container_name')

        # The worker pod may optionally have a valid Airflow config loaded via a
        # configmap
        self.airflow_configmap = conf.get(self.kubernetes_section, 'airflow_configmap')

        affinity_json = conf.get(self.kubernetes_section, 'affinity')
        if affinity_json:
            self.kube_affinity = json.loads(affinity_json)
        else:
            self.kube_affinity = None

        tolerations_json = conf.get(self.kubernetes_section, 'tolerations')
        if tolerations_json:
            self.kube_tolerations = json.loads(tolerations_json)
        else:
            self.kube_tolerations = None

        kube_client_request_args = conf.get(self.kubernetes_section, 'kube_client_request_args')
        if kube_client_request_args:
            self.kube_client_request_args = json.loads(kube_client_request_args)
            request_timeout = self.kube_client_request_args.get('_request_timeout')
            if request_timeout and isinstance(request_timeout, list):
                self.kube_client_request_args['_request_timeout'] = tuple(request_timeout)
        else:
            self.kube_client_request_args = {}
        self._validate()
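
The JSON handling at the end above expects _request_timeout either as a single number or as a two-element list that is converted to the (connect, read) tuple the Kubernetes client accepts; a hedged standalone sketch of that parsing, with the raw value shown purely as an illustration:

import json

# Illustrative raw value, as it might appear under
# [kubernetes] kube_client_request_args:
raw = '{"_request_timeout": [60, 60]}'

request_args = json.loads(raw) if raw else {}
timeout = request_args.get('_request_timeout')
if isinstance(timeout, list):
    # Convert to the (connect, read) tuple form.
    request_args['_request_timeout'] = tuple(timeout)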
Example #59
File: db_utils.py  Project: kalebinn/dbnd
def airlow_sql_alchemy_conn():
    from airflow.configuration import conf as airflow_conf

    return airflow_conf.get("core", "sql_alchemy_conn")
Example #60
def create_app(config=None, testing=False):
    """Create a new instance of Airflow WWW app"""
    flask_app = Flask(__name__)
    flask_app.secret_key = conf.get('webserver', 'SECRET_KEY')

    flask_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        minutes=settings.get_session_lifetime_config())
    flask_app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    flask_app.config['APP_NAME'] = conf.get(section="webserver",
                                            key="instance_name",
                                            fallback="Airflow")
    flask_app.config['TESTING'] = testing
    flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get(
        'core', 'SQL_ALCHEMY_CONN')
    flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    flask_app.config['SESSION_COOKIE_HTTPONLY'] = True
    flask_app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')

    cookie_samesite_config = conf.get('webserver', 'COOKIE_SAMESITE')
    if cookie_samesite_config == "":
        warnings.warn(
            "Old deprecated value found for `cookie_samesite` option in `[webserver]` section. "
            "Using `Lax` instead. Change the value to `Lax` in airflow.cfg to remove this warning.",
            DeprecationWarning,
        )
        cookie_samesite_config = "Lax"
    flask_app.config['SESSION_COOKIE_SAMESITE'] = cookie_samesite_config

    if config:
        flask_app.config.from_mapping(config)

    if 'SQLALCHEMY_ENGINE_OPTIONS' not in flask_app.config:
        flask_app.config[
            'SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()

    # Configure the JSON encoder used by `|tojson` filter from Flask
    flask_app.json_encoder = AirflowJsonEncoder

    csrf.init_app(flask_app)

    init_wsgi_middleware(flask_app)

    db = SQLA()
    db.session = settings.Session
    db.init_app(flask_app)

    init_dagbag(flask_app)

    init_api_experimental_auth(flask_app)

    Cache(app=flask_app,
          config={
              'CACHE_TYPE': 'filesystem',
              'CACHE_DIR': '/tmp'
          })

    init_flash_views(flask_app)

    configure_logging()
    configure_manifest_files(flask_app)

    with flask_app.app_context():
        init_appbuilder(flask_app)

        init_appbuilder_views(flask_app)
        init_appbuilder_links(flask_app)
        init_plugins(flask_app)
        init_connection_form()
        init_error_handlers(flask_app)
        init_api_connexion(flask_app)
        init_api_experimental(flask_app)

        sync_appbuilder_roles(flask_app)

        init_jinja_globals(flask_app)
        init_xframe_protection(flask_app)
        init_permanent_session(flask_app)
        init_airflow_session_interface(flask_app)
    return flask_app
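
A hedged usage sketch for the factory above, e.g. from a test or a WSGI entry point; the configuration override and the /health request are illustrative rather than required.

# Any mapping passed as config is merged into flask_app.config by the factory.
app = create_app(config={'WTF_CSRF_ENABLED': False}, testing=True)

with app.test_client() as client:
    response = client.get('/health')
    print(response.status_code)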