Example #1
    def __init__(self, subdag, executor=DEFAULT_EXECUTOR, *args, **kwargs):
        """
        Yo dawg. This runs a sub dag. By convention, a sub dag's dag_id
        should be prefixed by its parent and a dot. As in `parent.child`.

        :param subdag: the DAG object to run as a subdag of the current DAG.
        :type subdag: airflow.DAG
        :param dag: the parent DAG
        :type dag: airflow.DAG
        """
        if 'dag' not in kwargs:
            raise AirflowException("Please pass in the `dag` param")
        dag = kwargs['dag']
        super(SubDagOperator, self).__init__(*args, **kwargs)
        if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
            raise AirflowException(
                "The subdag's dag_id should correspond to the parent's "
                "'dag_id.task_id'")
        self.subdag = subdag
        self.executor = executor
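The constructor above only enforces the `parent.child` naming convention; a minimal standalone sketch of that check, with hypothetical dag and task ids (not Airflow code itself):

parent_dag_id = 'parent'        # hypothetical parent dag_id
task_id = 'child'               # hypothetical task_id for the SubDagOperator
subdag_dag_id = 'parent.child'  # the subdag dag_id the convention requires

# Mirrors the comparison in __init__ above; passes without raising.
assert subdag_dag_id == parent_dag_id + '.' + task_id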
Example #2
 def _parse_s3_url(self, s3url):
     parsed_url = urlparse(s3url)
     if not parsed_url.netloc:
         raise AirflowException('Please provide a bucket_name')
     else:
         bucket_name = parsed_url.netloc
         if parsed_url.path[0] == '/':
             key = parsed_url.path[1:]
         else:
             key = parsed_url.path
         return (bucket_name, key)
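The same parsing can be exercised outside Airflow with only the standard library; a minimal sketch, assuming a hypothetical s3://my-bucket/some/key.txt URL:

from urllib.parse import urlparse  # Python 3 location of urlparse

def split_s3_url(s3url):
    # Mirrors _parse_s3_url above: netloc is the bucket, the path without
    # its leading slash is the key.
    parsed = urlparse(s3url)
    if not parsed.netloc:
        raise ValueError('Please provide a bucket_name')
    return parsed.netloc, parsed.path.lstrip('/')

print(split_s3_url('s3://my-bucket/some/key.txt'))  # ('my-bucket', 'some/key.txt')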
Example #3
 def execute(self, context=None):
     hook = self.get_db_hook()
     logging.info('Executing SQL check: ' + self.sql2)
     row2 = hook.get_first(hql=self.sql2)
     logging.info('Executing SQL check: ' + self.sql1)
     row1 = hook.get_first(hql=self.sql1)
     if not row2:
         raise AirflowException("The query {q} returned None".format(
             q=self.sql2))
     if not row1:
         raise AirflowException("The query {q} returned None".format(
             q=self.sql1))
     current = dict(zip(self.metrics_sorted, row1))
     reference = dict(zip(self.metrics_sorted, row2))
     ratios = {}
     test_results = {}
     rlog = "Ratio for {0}: {1} \n Ratio threshold : {2}"
     fstr = "'{k}' check failed. {r} is above {tr}"
     estr = "The following tests have failed:\n {0}"
     countstr = "The following {j} tests out of {n} failed:"
     for m in self.metrics_sorted:
         if current[m] == 0 or reference[m] == 0:
             ratio = None
         else:
             ratio = float(max(current[m], reference[m])) / \
                 min(current[m], reference[m])
         logging.info(rlog.format(m, ratio, self.metrics_thresholds[m]))
         ratios[m] = ratio
         test_results[m] = ratio < self.metrics_thresholds[m]
     if not all(test_results.values()):
         failed_tests = [it[0] for it in test_results.items() if not it[1]]
         j = len(failed_tests)
         n = len(self.metrics_sorted)
         logging.warning(countstr.format(**locals()))
         for k in failed_tests:
             logging.warning(
                 fstr.format(k=k,
                             r=ratios[k],
                             tr=self.metrics_thresholds[k]))
         raise AirflowException(estr.format(", ".join(failed_tests)))
     logging.info("All tests have passed")
Example #4
def process_subdir(subdir):
    dags_folder = configuration.get("core", "DAGS_FOLDER")
    dags_folder = os.path.expanduser(dags_folder)
    if subdir:
        subdir = os.path.expanduser(subdir)
        if "DAGS_FOLDER" in subdir:
            subdir = subdir.replace("DAGS_FOLDER", dags_folder)
        if dags_folder not in subdir:
            raise AirflowException(
                "subdir has to be part of your DAGS_FOLDER as defined in your "
                "airflow.cfg")
        return subdir
Example #5
 def kill(self):
     session = settings.Session()
     job = session.query(BaseJob).filter(BaseJob.id == self.id).first()
     job.end_date = datetime.now()
     try:
         self.on_kill()
     except Exception:
         logging.error('on_kill() method failed')
     session.merge(job)
     session.commit()
     session.close()
     raise AirflowException("Job shut down externally.")
Example #6
    def _get_sessions(self):
        method = "GET"
        endpoint = "sessions"
        response = self._http_rest_call(method=method, endpoint=endpoint)

        if response.status_code in self.acceptable_response_codes:
            return response.json()["sessions"]
        else:
            raise AirflowException("Call to get sessions didn't return " +
                                   str(self.acceptable_response_codes) +
                                   ". Returned '" + str(response.status_code) +
                                   "'.")
Example #7
def parse_gcs_url(gsurl):
    """
    Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
    tuple containing the corresponding bucket and blob.
    """
    parsed_url = urlparse(gsurl)
    if not parsed_url.netloc:
        raise AirflowException('Please provide a bucket name')
    else:
        bucket = parsed_url.netloc
        blob = parsed_url.path.strip('/')
        return (bucket, blob)
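A quick way to see the expected output is a standalone sketch with a hypothetical URL; parse_gcs_url behaves the same inside Airflow:

from urllib.parse import urlparse  # the same parsing primitive parse_gcs_url uses

parsed = urlparse('gs://my-bucket/path/to/blob.csv')  # hypothetical gs:// URL
print(parsed.netloc, parsed.path.strip('/'))          # my-bucket path/to/blob.csv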
Example #8
 def get_connections(cls, conn_id):
     session = settings.Session()
     db = (
         session.query(Connection)
         .filter(Connection.conn_id == conn_id)
         .all()
     )
     if not db:
         raise AirflowException(
             "The conn_id `{0}` isn't defined".format(conn_id))
     session.expunge_all()
     session.close()
     return db
Example #9
def test(args):
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)

    if args.dry_run:
        ti.dry_run()
    else:
        ti.run(force=True, ignore_dependencies=True, test_mode=True)
Example #10
def set_is_paused(is_paused, args):
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]

    session = settings.Session()
    dm = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).first()
    dm.is_paused = is_paused
    session.commit()

    msg = "Dag: {}, paused: {}".format(dag, str(dag.is_paused))
    print(msg)
Example #11
def process_subdir(subdir):
    dags_folder = configuration.get("core", "DAGS_FOLDER")
    dags_folder = os.path.expanduser(dags_folder)
    if subdir:
        if "DAGS_FOLDER" in subdir:
            subdir = subdir.replace("DAGS_FOLDER", dags_folder)
        subdir = os.path.abspath(os.path.expanduser(subdir))
        if dags_folder.rstrip('/') not in subdir.rstrip('/'):
            raise AirflowException(
                "subdir has to be part of your DAGS_FOLDER as defined in your "
                "airflow.cfg. DAGS_FOLDER is {df} and subdir is {sd}".format(
                    df=dags_folder, sd=subdir))
        return subdir
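The core of this helper is the DAGS_FOLDER placeholder substitution plus path normalization; a minimal standalone sketch with a hypothetical dags folder:

import os

def expand_subdir(subdir, dags_folder='/home/airflow/dags'):
    # Hypothetical folder; the real function reads DAGS_FOLDER from airflow.cfg.
    subdir = subdir.replace('DAGS_FOLDER', dags_folder)
    return os.path.abspath(os.path.expanduser(subdir))

print(expand_subdir('DAGS_FOLDER/reports'))  # /home/airflow/dags/reports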
Example #12
 def execute(self, **kwargs):
     """
     SlackAPIOperator calls will not fail even if the call is not unsuccessful.
     It should not prevent a DAG from completing in success
     """
     if not self.api_params:
         self.construct_api_call_params()
     sc = SlackClient(self.token)
     rc = json.loads(
         sc.api_call(self.method, **self.api_params).decode('utf-8'))
     if not rc['ok']:
         logging.error("Slack API call failed ({})".format(rc['error']))
         raise AirflowException("Slack API call failed: ({})".format(
             rc['error']))
Example #13
    def _get_session_statements(self, session_id):
        method = "GET"
        endpoint = "sessions/" + str(session_id) + "/statements"
        response = self._http_rest_call(method=method, endpoint=endpoint)

        if response.status_code in self.acceptable_response_codes:
            response_json = response.json()
            statements = response_json["statements"]
            return statements
        else:
            raise AirflowException(
                "Call to get the session statement response didn't return " +
                str(self.acceptable_response_codes) + ". Returned '" +
                str(response.status_code) + "'.")
Example #14
    def __init__(self, subdag, executor=DEFAULT_EXECUTOR, *args, **kwargs):
        """
        Yo dawg. This runs a sub dag. By convention, a sub dag's dag_id
        should be prefixed by its parent and a dot. As in `parent.child`.

        :param subdag: the DAG object to run as a subdag of the current DAG.
        :type subdag: airflow.DAG
        :param dag: the parent DAG
        :type dag: airflow.DAG
        """
        if 'dag' not in kwargs:
            raise AirflowException("Please pass in the `dag` param")
        dag = kwargs['dag']
        super(SubDagOperator, self).__init__(*args, **kwargs)
        if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
            raise AirflowException(
                "The subdag's dag_id should have the form "
                "'{{parent_dag_id}}.{{this_task_id}}'. Expected "
                "'{d}.{t}'; received '{rcvd}'.".format(d=dag.dag_id,
                                                       t=kwargs['task_id'],
                                                       rcvd=subdag.dag_id))
        self.subdag = subdag
        self.executor = executor
Example #15
    def __init__(self, conn_id, sql, *args, **kwargs):

        super(SqlSensor, self).__init__(*args, **kwargs)

        self.sql = sql
        self.conn_id = conn_id

        session = settings.Session()
        db = session.query(DB).filter(DB.conn_id == conn_id).first()
        if not db:
            raise AirflowException("conn_id doesn't exist in the repository")
        self.hook = db.get_hook()
        session.commit()
        session.close()
Example #16
 def execute(self, context=None):
     logging.info('Executing SQL check: ' + self.sql)
     records = self.get_db_hook().get_first(hql=self.sql)
     if not records:
         raise AirflowException("The query returned None")
     test_results = []
     except_temp = ("Test failed.\nPass value:{self.pass_value}\n"
                    "Query:\n{self.sql}\nResults:\n{records!s}")
     if not self.is_numeric_value_check:
         tests = [str(r) == self.pass_value for r in records]
     elif self.is_numeric_value_check:
         try:
             num_rec = [float(r) for r in records]
         except (ValueError, TypeError) as e:
             cvestr = "Converting a result to float failed.\n"
             raise AirflowException(cvestr+except_temp.format(**locals()))
         if self.has_tolerance:
             tests = [
                 r / (1 + self.tol) <= self.pass_value <= r / (1 - self.tol)
                 for r in num_rec]
         else:
             tests = [r == self.pass_value for r in num_rec]
     if not all(tests):
         raise AirflowException(except_temp.format(**locals()))
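The tolerance branch above accepts a result r when r / (1 + tol) <= pass_value <= r / (1 - tol); a minimal worked sketch with hypothetical numbers:

# Hypothetical query result, expected value, and 10% tolerance.
r, pass_value, tol = 105.0, 100.0, 0.1
print(r / (1 + tol) <= pass_value <= r / (1 - tol))  # True -> within tolerance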
Example #17
def task_state(args):
    """
    Returns the state of a TaskInstance at the command line.

    >>> airflow task_state tutorial sleep 2015-01-01
    success
    """
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    print(ti.current_state())
Example #18
def render(args):
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.render_templates()
    for attr in task.__class__.template_fields:
        print(textwrap.dedent("""\
        # ----------------------------------------------------------
        # property: {}
        # ----------------------------------------------------------
        {}
        """.format(attr, getattr(task, attr))))
Example #19
    def execute(self, context):
        args = self.cls.parse(self.args)
        self.cmd = self.cls.create(**args)
        context['task_instance'].xcom_push(key='qbol_cmd_id', value=self.cmd.id)
        logging.info("Qubole command created with Id: {0} and Status: {1}".format(str(self.cmd.id), self.cmd.status))

        while not Command.is_done(self.cmd.status):
            time.sleep(Qubole.poll_interval)
            self.cmd = self.cls.find(self.cmd.id)
            logging.info("Command Id: {0} and Status: {1}".format(str(self.cmd.id), self.cmd.status))

        if self.kwargs.get('fetch_logs'):
            logging.info("Logs for Command Id: {0} \n{1}".format(str(self.cmd.id), self.cmd.get_log()))

        if self.cmd.status != 'done':
            raise AirflowException('Command Id: {0} failed with Status: {1}'.format(self.cmd.id, self.cmd.status))
Example #20
def test(args):
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    # Add CLI provided task_params to task.params
    if args.task_params:
        passed_in_params = json.loads(args.task_params)
        task.params.update(passed_in_params)
    ti = TaskInstance(task, args.execution_date)

    if args.dry_run:
        ti.dry_run()
    else:
        ti.run(force=True, ignore_dependencies=True, test_mode=True)
Example #21
 def __init__(
         self, bucket_name,
         prefix, delimiter='/',
         s3_conn_id='s3_default',
         *args, **kwargs):
     super(S3PrefixSensor, self).__init__(*args, **kwargs)
     session = settings.Session()
     db = session.query(DB).filter(DB.conn_id == s3_conn_id).first()
     if not db:
         raise AirflowException("conn_id doesn't exist in the repository")
     # Parse
     self.bucket_name = bucket_name
     self.prefix = prefix
     self.delimiter = delimiter
     self.full_url = "s3://" + bucket_name + '/' + prefix
     self.s3_conn_id = s3_conn_id
     session.commit()
     session.close()
Example #22
    def execute(self, context):

        url = '{endpoint}{pageid}/components/{component_id}.json'.format(
            endpoint=self.endpoint,
            pageid=self.pageid,
            component_id=self.component_id)

        header = {
            'Content-Type': 'application/json',
            'Authorization': 'OAuth {apikey}'.format(apikey=self.apikey)
        }
        body = json.dumps({"component": {"status": self.component_status}})

        response = requests.request("PATCH", url, data=body, headers=header)
        if response.status_code >= 400:
            logging.error('StatusPage API call failed: %s %s',
                          response.status_code, response.reason)
            raise AirflowException('StatusPage API call failed: %s %s' %
                                   (response.status_code, response.reason))
Example #23
    def execute(self, context):

        url = '{endpoint}/rest/api/latest/issue/{ticket_id}/transitions?expand=transitions.fields'.format(
            endpoint=self.endpoint, ticket_id=self.ticket_id)

        header = {
            'Content-Type': 'application/json',
            'Authorization': 'Basic {creds}'.format(
                creds=base64.b64encode(
                    (self.api_user + ":" + self.api_password).encode('utf-8')
                ).decode('utf-8'))
        }
        body = json.dumps({"transition": {"id": self.transition_id}})

        response = requests.request("POST", url, data=body, headers=header)
        if response.status_code >= 400:
            logging.error('JIRA API call failed: %s %s', response.status_code,
                          response.reason)
            raise AirflowException('JIRA API call failed: %s %s' %
                                   (response.status_code, response.reason))
Example #24
    def _submit_spark_script(self, session_id):
        method = "POST"
        endpoint = "sessions/" + str(session_id) + "/statements"

        logging.info("Executing Spark Script: \n" + str(self.spark_script))

        data = {'code': textwrap.dedent(self.spark_script)}

        response = self._http_rest_call(method=method,
                                        endpoint=endpoint,
                                        data=data)

        if response.status_code in self.acceptable_response_codes:
            response_json = response.json()
            return response_json["id"], response_json["state"]
        else:
            raise AirflowException(
                "Call to create a new statement didn't return " +
                str(self.acceptable_response_codes) + ". Returned '" +
                str(response.status_code) + "'.")
Example #25
    def run_cli(self, pig, verbose=True):
        """
        Run a pig script using the pig cli

        >>> ph = PigCliHook()
        >>> result = ph.run_cli("ls /;")
        >>> ("hdfs://" in result)
        True
        """

        with TemporaryDirectory(prefix='airflow_pigop_') as tmp_dir:
            with NamedTemporaryFile(dir=tmp_dir) as f:
                f.write(pig)
                f.flush()
                fname = f.name
                pig_bin = 'pig'
                cmd_extra = []

                pig_cmd = [pig_bin, '-f', fname] + cmd_extra

                if self.pig_properties:
                    pig_properties_list = self.pig_properties.split()
                    pig_cmd.extend(pig_properties_list)
                if verbose:
                    logging.info(" ".join(pig_cmd))
                sp = subprocess.Popen(pig_cmd,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.STDOUT,
                                      cwd=tmp_dir)
                self.sp = sp
                stdout = ''
                for line in iter(sp.stdout.readline, ''):
                    stdout += line
                    if verbose:
                        logging.info(line.strip())
                sp.wait()

                if sp.returncode:
                    raise AirflowException(stdout)

                return stdout
Example #26
    def max_partition(self, schema, table_name, field=None, filter=None):
        '''
        Returns the maximum value for all partitions in a table. Works only
        for tables that have a single partition key. For subpartitioned
        tables, we recommend using signal tables.

        >>> hh = HiveMetastoreHook()
        >>> t = 'static_babynames_partitioned'
        >>> hh.max_partition(schema='airflow', table_name=t)
        '2015-01-01'
        '''
        parts = self.get_partitions(schema, table_name, filter)
        if not parts:
            return None
        elif len(parts[0]) == 1:
            field = list(parts[0].keys())[0]
        elif not field:
            raise AirflowException("Please specify the field you want the max "
                                   "value for")

        return max([p[field] for p in parts])
Example #27
def backfill(args):
    logging.basicConfig(
        level=settings.LOGGING_LEVEL,
        format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, using same as start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(
            task_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id,
                                                 args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or configuration.getboolean('core', 'donot_pickle')),
            ignore_dependencies=args.ignore_dependencies,
            pool=args.pool)
Example #28
def load_login():
    auth_backend = 'airflow.default_login'
    try:
        if conf.getboolean('webserver', 'AUTHENTICATE'):
            auth_backend = conf.get('webserver', 'auth_backend')
    except conf.AirflowConfigException:
        if conf.getboolean('webserver', 'AUTHENTICATE'):
            logging.warning(
                "auth_backend not found in webserver config reverting to "
                "*deprecated*  behavior of importing airflow_login")
            auth_backend = "airflow_login"

    try:
        global login
        login = import_module(auth_backend)
    except ImportError as err:
        logging.critical(
            "Cannot import authentication module %s. "
            "Please correct your authentication backend or disable authentication: %s",
            auth_backend, err)
        if conf.getboolean('webserver', 'AUTHENTICATE'):
            raise AirflowException("Failed to import authentication backend")
Example #29
    def execute(self, context):
        """
        Execute the bash command in a temporary directory
        which will be cleaned afterwards
        """
        bash_command = self.bash_command
        logging.info("tmp dir root location: \n" + gettempdir())
        with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
            with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f:

                f.write(bytes(bash_command, 'utf_8'))
                f.flush()
                fname = f.name
                script_location = fname  # f.name is already an absolute path inside tmp_dir
                logging.info("Temporary script "
                             "location :{0}".format(script_location))
                logging.info("Running command: " + bash_command)
                sp = Popen(['bash', fname],
                           stdout=PIPE,
                           stderr=STDOUT,
                           cwd=tmp_dir,
                           env=self.env)

                self.sp = sp

                logging.info("Output:")
                line = ''
                for line in iter(sp.stdout.readline, b''):
                    line = line.decode(self.output_encoding).strip()
                    logging.info(line)
                sp.wait()
                logging.info("Command exited with "
                             "return code {0}".format(sp.returncode))

                if sp.returncode:
                    raise AirflowException("Bash command failed")

        if self.xcom_push_flag:
            return line
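The streaming pattern in execute() can be tried standalone; a minimal sketch, assuming only the standard library and a throwaway bash command:

from subprocess import Popen, PIPE, STDOUT

# Stream stdout line by line, then check the return code, as execute() does above.
sp = Popen(['bash', '-c', 'echo hello'], stdout=PIPE, stderr=STDOUT)
for raw in iter(sp.stdout.readline, b''):
    print(raw.decode('utf-8').strip())
sp.wait()
if sp.returncode:
    raise RuntimeError('Bash command failed')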
Example #30
    def run_cli(self, hql, schema=None):
        '''
        Run an hql statement using the hive cli

        >>> hh = HiveCliHook()
        >>> result = hh.run_cli("USE airflow;")
        >>> ("OK" in result)
        True
        '''
        if schema:
            hql = "USE {schema};\n{hql}".format(**locals())

        with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir:
            with NamedTemporaryFile(dir=tmp_dir) as f:
                f.write(hql)
                f.flush()
                fname = f.name
                hive_cmd = ['hive', '-f', fname]
                if self.hive_cli_params:
                    hive_params_list = self.hive_cli_params.split()
                    hive_cmd.extend(hive_params_list)
                sp = subprocess.Popen(hive_cmd,
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.STDOUT,
                                      cwd=tmp_dir)
                self.sp = sp
                stdout = ""
                for line in iter(sp.stdout.readline, ''):
                    stdout += line
                    logging.info(line.strip())
                sp.wait()

                if sp.returncode:
                    raise AirflowException(stdout)

                return stdout
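The schema handling at the top of run_cli is plain string templating; a minimal sketch with hypothetical values:

schema, hql = 'airflow', 'SELECT 1;'  # hypothetical schema and statement
print('USE {schema};\n{hql}'.format(schema=schema, hql=hql))
# USE airflow;
# SELECT 1;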