Code Example #1
def pool_set(args):
    """Creates new pool with a given name and slots"""
    api_client = get_current_api_client()
    log = LoggingMixin().log
    pools = [
        api_client.create_pool(name=args.pool,
                               slots=args.slots,
                               description=args.description)
    ]
    log.info(_tabulate_pools(pools=pools, tablefmt=args.output))
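All of the snippets on this page hinge on the same small pattern: LoggingMixin exposes a log property, so code either instantiates it ad hoc (LoggingMixin().log) or inherits from it to get self.log. A minimal sketch of both styles, assuming Airflow is installed; the import path shown is the one used by Airflow 1.10 and later, and PoolReporter is an illustrative name, not something from the examples above.

from airflow.utils.log.logging_mixin import LoggingMixin

# Ad-hoc usage, as in the CLI helpers on this page: grab a logger from a throwaway instance.
log = LoggingMixin().log
log.info("message from a plain function")

# Subclassing, which is how Airflow operators and hooks normally get self.log.
class PoolReporter(LoggingMixin):
    def report(self, pools):
        self.log.info("found %d pools", len(pools))

PoolReporter().report([])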
Code Example #2
def pool_import(args):
    """Imports pools from the file"""
    api_client = get_current_api_client()
    log = LoggingMixin().log
    if os.path.exists(args.file):
        pools = pool_import_helper(args.file)
    else:
        print("Missing pools file.")
        pools = api_client.get_pools()
    log.info(_tabulate_pools(pools=pools, tablefmt=args.output))
Code Example #3
    def get_default_executor(cls) -> BaseExecutor:
        """Creates a new instance of the configured executor if none exists and returns it"""
        if cls._default_executor is not None:
            return cls._default_executor

        from airflow.configuration import conf
        executor_name = conf.get('core', 'EXECUTOR')

        cls._default_executor = ExecutorLoader._get_executor(executor_name)

        from airflow import LoggingMixin
        log = LoggingMixin().log
        log.info("Using executor %s", executor_name)

        return cls._default_executor
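The class-level cache above means the configured executor is constructed only once per process; later calls return the same instance. A hedged usage sketch follows; the module that exposes ExecutorLoader differs between Airflow releases, so the import path is an assumption.

# Hedged sketch: ExecutorLoader's module path varies across Airflow versions.
from airflow.executors.executor_loader import ExecutorLoader

executor = ExecutorLoader.get_default_executor()           # reads [core] executor the first time
assert executor is ExecutorLoader.get_default_executor()   # cached instance on later calls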
Code Example #4
    def handle_retry(context, retry_fn):
        SpringBootJarOperator.set_task_retry_value(context,
                                                   RETRY_RUNNING_STATE)

        LoggingMixin().log.info("running retry_fn")
        try:
            retry_fn(context)
            SpringBootJarOperator.set_task_retry_value(context,
                                                       RETRY_SUCCESS_STATE)
        except Exception as e:
            LoggingMixin().log.exception("failed running retry function")
            SpringBootJarOperator.set_task_retry_value(context,
                                                       RETRY_FAIL_STATE)
            raise AirflowException("Retry function failed")

        LoggingMixin().log.info("end handle_retry")
Code Example #5
def find_dag_file_paths(file_paths, files, patterns, root, safe_mode):
    """Finds file paths of all DAG files."""
    for f in files:
        # noinspection PyBroadException
        try:
            file_path = os.path.join(root, f)
            if not os.path.isfile(file_path):
                continue
            _, file_ext = os.path.splitext(os.path.split(file_path)[-1])
            if file_ext != '.py' and not zipfile.is_zipfile(file_path):
                continue
            if any([re.findall(p, file_path) for p in patterns]):
                continue

            if not might_contain_dag(file_path, safe_mode):
                continue

            file_paths.append(file_path)
        except Exception:  # pylint: disable=broad-except
            log = LoggingMixin().log
            log.exception("Error while examining %s", f)
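The might_contain_dag call above is a safe-mode pre-filter: files that do not even mention the relevant keywords are skipped before Airflow tries to import them. A rough, illustrative sketch of that idea; the exact tokens Airflow checks for are an assumption here, not taken from its source.

# Illustrative keyword pre-filter in the spirit of might_contain_dag.
# The tokens below are assumptions; Airflow's own heuristic may differ.
def looks_like_dag_file(file_path: str) -> bool:
    with open(file_path, "rb") as handle:
        content = handle.read().lower()
    return all(token in content for token in (b"dag", b"airflow"))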
Code Example #6
File: app.py  Project: yysun21/incubator-airflow
 def integrate_plugins():
     """Integrate plugins to the context"""
     log = LoggingMixin().log
     from airflow.plugins_manager import (
         admin_views, flask_blueprints, menu_links)
     for v in admin_views:
         log.debug('Adding view %s', v.name)
         admin.add_view(v)
     for bp in flask_blueprints:
         log.debug('Adding blueprint %s', bp.name)
         app.register_blueprint(bp)
     for ml in sorted(menu_links, key=lambda x: x.name):
         log.debug('Adding menu link %s', ml.name)
         admin.add_link(ml)
Code Example #7
File: app.py  Project: 7digital/incubator-airflow
 def integrate_plugins():
     """Integrate plugins to the context"""
     log = LoggingMixin().log
     from airflow.plugins_manager import (
         admin_views, flask_blueprints, menu_links)
     for v in admin_views:
         log.debug('Adding view %s', v.name)
         admin.add_view(v)
     for bp in flask_blueprints:
         log.debug('Adding blueprint %s', bp.name)
         app.register_blueprint(bp)
     for ml in sorted(menu_links, key=lambda x: x.name):
         log.debug('Adding menu link %s', ml.name)
         admin.add_link(ml)
Code Example #8
 def get_retry_args(self, bash_command, command):
     """
     Append java_retry_args if not empty, otherwise append java_args
     """
     logger = LoggingMixin().log
     if self.java_retry_args:
         logger.debug("java_retry_args: {0}".format(self.java_retry_args))
         self.get_args(self.java_retry_args, bash_command, command)
     else:
         logger.debug("java_args: {0}".format(self.java_args))
         self.get_args(self.java_args, bash_command, command)
Code Example #9
def authorize(oauth_app, authorized_response, user_info):
    with open('/run/secrets/kubernetes.io/serviceaccount/namespace',
              'r') as file:
        namespace = file.read()
    kube_client = get_kube_client()

    url = "{0}/apis/rbac.authorization.k8s.io/v1beta1/namespaces/{1}/rolebindings".format(
        kube_client.api_client.configuration.host, namespace)

    response = requests.get(
        url,
        headers={
            "Authorization": "Bearer {0}".format(oauth_app.consumer_secret)
        },
        verify=kube_client.api_client.configuration.ssl_ca_cert
        if kube_client.api_client.configuration.ssl_ca_cert else False)
    if response.status_code != 200:
        LoggingMixin().log.error(
            "The service account providing OAuth is not allowed to list rolebindings. Deniyng "
            "access to everyone!!!")
        return False, False

    role_binding_list = response.json()
    allowed_roles = []
    for role in role_binding_list['items']:

        def predicate(subject):
            if subject['kind'] in ['ServiceAccount', 'User']:
                return subject['name'] == user_info['metadata']['name']
            elif subject['kind'] == 'Group':
                return subject['name'] in user_info['groups']

        name = role['roleRef']['name']
        if next((x for x in role['subjects'] if predicate(x)), None):
            allowed_roles.append(name)

    allowed_roles = set(allowed_roles)
    access_roles = set(
        configuration.conf.get('openshift_plugin', 'access_roles').split(','))
    superuser_roles = set(
        configuration.conf.get('openshift_plugin',
                               'superuser_roles').split(','))

    return bool(allowed_roles & access_roles), \
           bool(allowed_roles & superuser_roles)
Code Example #10
    def test_args_from_cli(self):
        """
        We expect no result, but a run with sys.exit(1) because the keytab does not exist.
        """
        configuration.conf.set("kerberos", "keytab", "")
        self.args.keytab = "test_keytab"

        with self.assertRaises(SystemExit) as se:
            renew_from_kt(principal=self.args.principal,
                          keytab=self.args.keytab)

            with self.assertLogs(LoggingMixin().log) as log:
                self.assertIn(
                    'kinit: krb5_init_creds_set_keytab: Failed to find '
                    '[email protected] in keytab FILE:{} '
                    '(unknown enctype)'.format(self.args.keytab), log.output)

        self.assertEqual(se.exception.code, 1)
Code Example #11
    def test_args_from_cli(self):
        """
        We expect no result, but a run with sys.exit(1) because the keytab does not exist.
        """
        self.args.keytab = "test_keytab"

        with conf_vars({('kerberos', 'keytab'): ''}):
            with self.assertRaises(SystemExit) as err:
                renew_from_kt(principal=self.args.principal,  # pylint: disable=no-member
                              keytab=self.args.keytab)

                with self.assertLogs(LoggingMixin().log) as log:
                    self.assertIn(
                        'kinit: krb5_init_creds_set_keytab: Failed to find '
                        '[email protected] in keytab FILE:{} '
                        '(unknown enctype)'.format(self.args.keytab), log.output)

                self.assertEqual(err.exception.code, 1)
Code Example #12
def task_run(args, dag=None):
    """Runs a single task instance"""
    if dag:
        args.dag_id = dag.dag_id

    log = LoggingMixin().log

    # Load custom airflow config
    if args.cfg_path:
        with open(args.cfg_path, 'r') as conf_file:
            conf_dict = json.load(conf_file)

        if os.path.exists(args.cfg_path):
            os.remove(args.cfg_path)

        conf.read_dict(conf_dict, source=args.cfg_path)
        settings.configure_vars()

    # IMPORTANT, have to use the NullPool, otherwise, each "run" command may leave
    # behind multiple open sleeping connections while heartbeating, which could
    # easily exceed the database connection limit when
    # processing hundreds of simultaneous tasks.
    settings.configure_orm(disable_connection_pool=True)

    if not args.pickle and not dag:
        dag = get_dag(args)
    elif not dag:
        with db.create_session() as session:
            log.info('Loading pickle id %s', args.pickle)
            dag_pickle = session.query(DagPickle).filter(
                DagPickle.id == args.pickle).first()
            if not dag_pickle:
                raise AirflowException("Who hid the pickle!? [missing pickle]")
            dag = dag_pickle.pickle

    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.refresh_from_db()

    ti.init_run_context(raw=args.raw)

    hostname = get_hostname()
    log.info("Running %s on host %s", ti, hostname)

    if args.interactive:
        _run(args, dag, ti)
    else:
        with redirect_stdout(ti.log, logging.INFO), redirect_stderr(
                ti.log, logging.WARN):
            _run(args, dag, ti)
    logging.shutdown()
Code Example #13
def dag_trigger(args):
    """
    Creates a dag run for the specified dag
    """
    api_client = get_current_api_client()
    log = LoggingMixin().log
    try:
        message = api_client.trigger_dag(dag_id=args.dag_id,
                                         run_id=args.run_id,
                                         conf=args.conf,
                                         execution_date=args.exec_date)
    except OSError as err:
        log.error(err)
        raise AirflowException(err)
    log.info(message)
Code Example #14
def dag_delete(args):
    """
    Deletes all DB records related to the specified dag
    """
    api_client = get_current_api_client()
    log = LoggingMixin().log
    if args.yes or input(
            "This will drop all existing records related to the specified DAG. "
            "Proceed? (y/n)").upper() == "Y":
        try:
            message = api_client.delete_dag(dag_id=args.dag_id)
        except OSError as err:
            log.error(err)
            raise AirflowException(err)
        log.info(message)
    else:
        print("Bail.")
Code Example #15
File: app.py  Project: danielvdende/incubator-airflow
def create_app(config=None, testing=False):

    log = LoggingMixin().log

    app = Flask(__name__)
    app.wsgi_app = ProxyFix(app.wsgi_app)

    if configuration.conf.get('webserver', 'SECRET_KEY') == "temporary_key":
        log.info("SECRET_KEY for Flask App is not specified. Using a random one.")
        app.secret_key = os.urandom(16)
    else:
        app.secret_key = configuration.conf.get('webserver', 'SECRET_KEY')

    app.config['LOGIN_DISABLED'] = not configuration.conf.getboolean(
        'webserver', 'AUTHENTICATE')

    csrf.init_app(app)

    app.config['TESTING'] = testing

    airflow.load_login()
    airflow.login.login_manager.init_app(app)

    from airflow import api
    api.load_auth()
    api.api_auth.init_app(app)

    cache = Cache(
        app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})

    app.register_blueprint(routes)

    configure_logging()

    with app.app_context():
        from airflow.www import views

        admin = Admin(
            app, name='Airflow',
            static_url_path='/admin',
            index_view=views.HomeView(endpoint='', url='/admin', name="DAGs"),
            template_mode='bootstrap3',
        )
        av = admin.add_view
        vs = views
        av(vs.Airflow(name='DAGs', category='DAGs'))

        if not conf.getboolean('core', 'secure_mode'):
            av(vs.QueryView(name='Ad Hoc Query', category="Data Profiling"))
            av(vs.ChartModelView(
                models.Chart, Session, name="Charts", category="Data Profiling"))
        av(vs.KnownEventView(
            models.KnownEvent,
            Session, name="Known Events", category="Data Profiling"))
        av(vs.SlaMissModelView(
            models.SlaMiss,
            Session, name="SLA Misses", category="Browse"))
        av(vs.TaskInstanceModelView(models.TaskInstance,
            Session, name="Task Instances", category="Browse"))
        av(vs.LogModelView(
            models.Log, Session, name="Logs", category="Browse"))
        av(vs.JobModelView(
            jobs.BaseJob, Session, name="Jobs", category="Browse"))
        av(vs.PoolModelView(
            models.Pool, Session, name="Pools", category="Admin"))
        av(vs.ConfigurationView(
            name='Configuration', category="Admin"))
        av(vs.UserModelView(
            models.User, Session, name="Users", category="Admin"))
        av(vs.ConnectionModelView(
            models.Connection, Session, name="Connections", category="Admin"))
        av(vs.VariableView(
            models.Variable, Session, name="Variables", category="Admin"))
        av(vs.XComView(
            models.XCom, Session, name="XComs", category="Admin"))

        admin.add_link(base.MenuLink(
            category='Docs', name='Documentation',
            url='https://airflow.incubator.apache.org/'))
        admin.add_link(
            base.MenuLink(category='Docs',
                          name='Github',
                          url='https://github.com/apache/incubator-airflow'))

        av(vs.VersionView(name='Version', category="About"))

        av(vs.DagRunModelView(
            models.DagRun, Session, name="DAG Runs", category="Browse"))
        av(vs.DagModelView(models.DagModel, Session, name=None))
        # Hack to not add this view to the menu
        admin._menu = admin._menu[:-1]

        def integrate_plugins():
            """Integrate plugins to the context"""
            from airflow.plugins_manager import (
                admin_views, flask_blueprints, menu_links)
            for v in admin_views:
                log.debug('Adding view %s', v.name)
                admin.add_view(v)
            for bp in flask_blueprints:
                log.debug('Adding blueprint %s', bp.name)
                app.register_blueprint(bp)
            for ml in sorted(menu_links, key=lambda x: x.name):
                log.debug('Adding menu link %s', ml.name)
                admin.add_link(ml)

        integrate_plugins()

        import airflow.www.api.experimental.endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            if six.PY2:
                reload(e)
            else:
                import importlib
                importlib.reload(e)

        app.register_blueprint(e.api_experimental, url_prefix='/api/experimental')

        @app.context_processor
        def jinja_globals():
            return {
                'hostname': get_hostname(),
                'navbar_color': configuration.get('webserver', 'NAVBAR_COLOR'),
            }

        @app.teardown_appcontext
        def shutdown_session(exception=None):
            settings.Session.remove()

        return app
Code Example #16
    def clean_before_retry(context):
        logger = LoggingMixin().log
        logger.info("executing default retry handler")

        if 'retry_command' in context['params']:
            bash_command = context['params']['retry_command']
            if 'retry_java_args_method' in context['params']:
                bash_command = bash_command + ' ' + context['params'][
                    'retry_java_args_method'](context)
            logger.info("tmp dir root location: \n" + gettempdir())
            task_instance_key_str = context['task_instance_key_str']
            with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
                with NamedTemporaryFile(dir=tmp_dir,
                                        prefix=("retry_%s" %
                                                task_instance_key_str)) as f:
                    f.write(bash_command)
                    f.flush()
                    fname = f.name
                    script_location = tmp_dir + "/" + fname
                    logger.info("Temporary script "
                                "location :{0}".format(script_location))
                    logger.info("Running retry command: " + bash_command)
                    sp = Popen(['bash', fname],
                               stdout=PIPE,
                               stderr=STDOUT,
                               cwd=tmp_dir,
                               preexec_fn=os.setsid)

                    logger.info("Retry command output:")
                    line = ''
                    for line in iter(sp.stdout.readline, b''):
                        line = line.decode("UTF-8").strip()
                        logger.info(line)
                    sp.wait()
                    logger.info("Retry command exited with "
                                "return code {0}".format(sp.returncode))

                    if sp.returncode:
                        raise AirflowException("Retry bash command failed")
Code Example #17
 def set_task_retry_value(context, value):
     task_retry_key = SpringBootJarOperator.get_task_retry_key(context)
     LoggingMixin().log.info('Setting task retry key %s with value %s' %
                             (task_retry_key, value))
     Variable.set(task_retry_key, value)
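set_task_retry_value simply persists the retry state in an Airflow Variable, which handle_retry in Code Example #4 updates as the retry progresses. The underlying round-trip looks as follows; the key and value below are illustrative only, and running it assumes an initialized Airflow metadata database.

# Hedged sketch of the Variable round-trip; key and value are made up.
from airflow.models import Variable

Variable.set("my_dag__my_task__20200101_retry", "RETRY_RUNNING_STATE")
state = Variable.get("my_dag__my_task__20200101_retry")
print(state)  # -> "RETRY_RUNNING_STATE"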
Code Example #18
def pool_list(args):
    """Displays info of all the pools"""
    api_client = get_current_api_client()
    log = LoggingMixin().log
    pools = api_client.get_pools()
    log.info(_tabulate_pools(pools=pools, tablefmt=args.output))
Code Example #19
def pool_get(args):
    """Displays pool info by a given name"""
    api_client = get_current_api_client()
    log = LoggingMixin().log
    pools = [api_client.get_pool(name=args.pool)]
    log.info(_tabulate_pools(pools=pools, tablefmt=args.output))
Code Example #20
                # Bypass set_upstream etc here - it does more than we want
                # noinspection PyProtectedMember
                dag.task_dict[task_id]._upstream_task_ids.add(task_id)  # pylint: disable=protected-access

        return dag

    @classmethod
    def to_dict(cls, var: Any) -> dict:
        """Stringifies DAGs and operators contained by var and returns a dict of var.
        """
        json_dict = {
            "__version": cls.SERIALIZER_VERSION,
            "dag": cls.serialize_dag(var)
        }

        # Validate Serialized DAG with Json Schema. Raises Error if it mismatches
        cls.validate_schema(json_dict)
        return json_dict

    @classmethod
    def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG':
        """Deserializes a python dict in to the DAG and operators it contains."""
        ver = serialized_obj.get('__version', '<not present>')
        if ver != cls.SERIALIZER_VERSION:
            raise ValueError("Unsure how to deserialize version {!r}".format(ver))
        return cls.deserialize_dag(serialized_obj['dag'])


LOG = LoggingMixin().log
FAILED = 'serialization_failed'
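The to_dict/from_dict pair above gives a JSON-safe round trip for DAGs, and from_dict rejects payloads whose __version does not match SERIALIZER_VERSION. A hedged usage sketch, assuming SerializedDAG lives in airflow.serialization.serialized_objects as it does in recent Airflow releases:

# Hedged round-trip sketch; the module path is an assumption about the Airflow version.
from datetime import datetime
from airflow import DAG
from airflow.serialization.serialized_objects import SerializedDAG

dag = DAG("serialization_demo", start_date=datetime(2020, 1, 1), schedule_interval=None)
payload = SerializedDAG.to_dict(dag)         # {"__version": 1, "dag": {...}}
restored = SerializedDAG.from_dict(payload)  # raises ValueError on a version mismatch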
Code Example #21
File: e2e_testing.py  Project: chandnipatelTW/twdu2b
from datetime import datetime, timedelta
from airflow.operators.bash_operator import BashOperator
from airflow import DAG
from airflow import LoggingMixin
from airflow.models import Variable

logger = LoggingMixin()

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime.today().strftime('%Y-%m-%d'),
    'retries': 0,
    'retry_delay': timedelta(minutes=5),
}

dag = DAG('e2e_testing',
          default_args=default_args,
          schedule_interval='0 * * * *',
          catchup=False)

initiliaze_e2e_script = "e2e_test.sh"

task_execute_e2e = BashOperator(task_id='execute_e2e_test',
                                bash_command=initiliaze_e2e_script,
                                dag=dag)
Code Example #22
from typing import Optional
import socket
import subprocess
import sys
import time

from airflow import configuration, LoggingMixin

NEED_KRB181_WORKAROUND = None  # type: Optional[bool]

log = LoggingMixin().log


def renew_from_kt(principal, keytab):
    # The config is specified in seconds. But we ask for that same amount in
    # minutes to give ourselves a large renewal buffer.

    renewal_lifetime = "%sm" % configuration.conf.getint(
        'kerberos', 'reinit_frequency')

    cmd_principal = principal or configuration.conf.get(
        'kerberos', 'principal').replace("_HOST", socket.getfqdn())

    cmdv = [
        configuration.conf.get('kerberos', 'kinit_path'),
        "-r",
Code Example #23
def pool_delete(args):
    """Deletes pool by a given name"""
    api_client = get_current_api_client()
    log = LoggingMixin().log
    pools = [api_client.delete_pool(name=args.pool)]
    log.info(_tabulate_pools(pools=pools, tablefmt=args.output))
Code Example #24
def pool_export(args):
    """Exports all of the pools to the file"""
    log = LoggingMixin().log
    pools = pool_export_helper(args.file)
    log.info(_tabulate_pools(pools=pools, tablefmt=args.output))
Code Example #25
    "owner": "Victor Costa",
    "depends_on_past": False,
    "start_date": datetime(2020, 1, 1)
}

dag = DAG(
    dag_id=dag_id,
    default_args=default_args,
    description="",
    catchup=False,
    max_active_runs=1,
    schedule_interval="@daily",
    concurrency=4,
)

log = LoggingMixin().log
ddl_sql_file_name = "../create_tables.sql"
sql_path = path.join(path.dirname(path.abspath(__file__)), ddl_sql_file_name)
sql_content = None
try:
    with open(sql_path) as reader:
        sql_content = reader.read()

except Exception as err:
    log.error(f"Failure when reading file {sql_path}")

def fetch_riot_items_data():
    api_key = Variable.get("RIOT_API_KEY")

# VALIDATORS
validator_staging_game_match = DataQualityValidator(
Code Example #26
def create_app(config=None, testing=False):

    log = LoggingMixin().log

    app = Flask(__name__)
    if configuration.conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
        app.wsgi_app = ProxyFix(app.wsgi_app)
    app.secret_key = configuration.conf.get('webserver', 'SECRET_KEY')
    app.config['LOGIN_DISABLED'] = not configuration.conf.getboolean(
        'webserver', 'AUTHENTICATE')

    csrf.init_app(app)

    app.config['TESTING'] = testing

    airflow.load_login()
    airflow.login.login_manager.init_app(app)

    from airflow import api
    api.load_auth()
    api.api_auth.init_app(app)

    # flake8: noqa: F841
    cache = Cache(app=app,
                  config={
                      'CACHE_TYPE': 'filesystem',
                      'CACHE_DIR': '/tmp'
                  })

    app.register_blueprint(routes)

    configure_logging()

    with app.app_context():
        from airflow.www import views

        admin = Admin(
            app,
            name='Airflow',
            static_url_path='/admin',
            index_view=views.HomeView(endpoint='', url='/admin', name="DAGs"),
            template_mode='bootstrap3',
        )
        av = admin.add_view
        vs = views
        av(vs.Airflow(name='DAGs', category='DAGs'))

        if not conf.getboolean('core', 'secure_mode'):
            av(vs.QueryView(name='Ad Hoc Query', category="Data Profiling"))
            av(
                vs.ChartModelView(models.Chart,
                                  Session,
                                  name="Charts",
                                  category="Data Profiling"))
        av(
            vs.KnownEventView(models.KnownEvent,
                              Session,
                              name="Known Events",
                              category="Data Profiling"))
        av(
            vs.SlaMissModelView(models.SlaMiss,
                                Session,
                                name="SLA Misses",
                                category="Browse"))
        av(
            vs.TaskInstanceModelView(models.TaskInstance,
                                     Session,
                                     name="Task Instances",
                                     category="Browse"))
        av(vs.LogModelView(models.Log, Session, name="Logs",
                           category="Browse"))
        av(
            vs.JobModelView(jobs.BaseJob,
                            Session,
                            name="Jobs",
                            category="Browse"))
        av(
            vs.PoolModelView(models.Pool,
                             Session,
                             name="Pools",
                             category="Admin"))
        av(vs.ConfigurationView(name='Configuration', category="Admin"))
        av(
            vs.UserModelView(models.User,
                             Session,
                             name="Users",
                             category="Admin"))
        av(
            vs.ConnectionModelView(Connection,
                                   Session,
                                   name="Connections",
                                   category="Admin"))
        av(
            vs.VariableView(models.Variable,
                            Session,
                            name="Variables",
                            category="Admin"))
        av(vs.XComView(models.XCom, Session, name="XComs", category="Admin"))

        admin.add_link(
            base.MenuLink(category='Docs',
                          name='Documentation',
                          url='https://airflow.apache.org/'))
        admin.add_link(
            base.MenuLink(category='Docs',
                          name='Github',
                          url='https://github.com/apache/airflow'))

        av(vs.VersionView(name='Version', category="About"))

        av(
            vs.DagRunModelView(models.DagRun,
                               Session,
                               name="DAG Runs",
                               category="Browse"))
        av(vs.DagModelView(models.DagModel, Session, name=None))
        # Hack to not add this view to the menu
        admin._menu = admin._menu[:-1]

        def integrate_plugins():
            """Integrate plugins to the context"""
            from airflow.plugins_manager import (admin_views, flask_blueprints,
                                                 menu_links)
            for v in admin_views:
                log.debug('Adding view %s', v.name)
                admin.add_view(v)
            for bp in flask_blueprints:
                log.debug('Adding blueprint %s', bp.name)
                app.register_blueprint(bp)
            for ml in sorted(menu_links, key=lambda x: x.name):
                log.debug('Adding menu link %s', ml.name)
                admin.add_link(ml)

        integrate_plugins()

        import airflow.www.api.experimental.endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            six.moves.reload_module(e)

        app.register_blueprint(e.api_experimental,
                               url_prefix='/api/experimental')

        @app.context_processor
        def jinja_globals():
            return {
                'hostname': get_hostname(),
                'navbar_color': configuration.get('webserver', 'NAVBAR_COLOR'),
            }

        @app.teardown_appcontext
        def shutdown_session(exception=None):
            settings.Session.remove()

        return app
Code Example #27
def worker(args):
    """Starts Airflow Celery worker"""
    env = os.environ.copy()
    env['AIRFLOW_HOME'] = settings.AIRFLOW_HOME

    if not settings.validate_session():
        log = LoggingMixin().log
        log.error("Worker exiting... database connection precheck failed! ")
        sys.exit(1)

    # Celery worker
    from airflow.executors.celery_executor import app as celery_app
    from celery.bin import worker  # pylint: disable=redefined-outer-name

    autoscale = args.autoscale
    if autoscale is None and conf.has_option("celery", "worker_autoscale"):
        autoscale = conf.get("celery", "worker_autoscale")
    worker = worker.worker(app=celery_app)  # pylint: disable=redefined-outer-name
    options = {
        'optimization': 'fair',
        'O': 'fair',
        'queues': args.queues,
        'concurrency': args.concurrency,
        'autoscale': autoscale,
        'hostname': args.celery_hostname,
        'loglevel': conf.get('core', 'LOGGING_LEVEL'),
    }

    if conf.has_option("celery", "pool"):
        options["pool"] = conf.get("celery", "pool")

    if args.daemon:
        pid, stdout, stderr, log_file = setup_locations(
            "worker", args.pid, args.stdout, args.stderr, args.log_file)
        handle = setup_logging(log_file)
        stdout = open(stdout, 'w+')
        stderr = open(stderr, 'w+')

        ctx = daemon.DaemonContext(
            pidfile=TimeoutPIDLockFile(pid, -1),
            files_preserve=[handle],
            stdout=stdout,
            stderr=stderr,
        )
        with ctx:
            sub_proc = subprocess.Popen(['airflow', 'serve_logs'],
                                        env=env,
                                        close_fds=True)
            worker.run(**options)
            sub_proc.kill()

        stdout.close()
        stderr.close()
    else:
        signal.signal(signal.SIGINT, sigint_handler)
        signal.signal(signal.SIGTERM, sigint_handler)

        sub_proc = subprocess.Popen(['airflow', 'serve_logs'],
                                    env=env,
                                    close_fds=True)

        worker.run(**options)
        sub_proc.kill()