def __init__(self, subdag, executor=DEFAULT_EXECUTOR, *args, **kwargs):
    """
    Yo dawg. This runs a sub dag. By convention, a sub dag's dag_id
    should be prefixed by its parent and a dot. As in `parent.child`.

    :param subdag: the DAG object to run as a subdag of the current DAG.
    :type subdag: airflow.DAG
    :param dag: the parent DAG
    :type dag: airflow.DAG
    """
    if 'dag' not in kwargs:
        raise AirflowException("Please pass in the `dag` param")
    dag = kwargs['dag']
    super(SubDagOperator, self).__init__(*args, **kwargs)
    if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
        raise AirflowException(
            "The subdag's dag_id should correspond to the parent's "
            "'dag_id.task_id'")
    self.subdag = subdag
    self.executor = executor
def _parse_s3_url(self, s3url):
    parsed_url = urlparse(s3url)
    if not parsed_url.netloc:
        raise AirflowException('Please provide a bucket_name')
    else:
        bucket_name = parsed_url.netloc
        if parsed_url.path[0] == '/':
            key = parsed_url.path[1:]
        else:
            key = parsed_url.path
        return (bucket_name, key)
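A minimal usage sketch of the parser above; the hook instance name and the URL are assumed here purely for illustration:

# Illustrative only: assumes `hook` is an instance of the S3 hook defined above.
bucket_name, key = hook._parse_s3_url('s3://my-bucket/data/part-0000.csv')
# bucket_name == 'my-bucket'
# key == 'data/part-0000.csv'
# A URL without a bucket (no netloc) raises AirflowException.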
def execute(self, context=None):
    hook = self.get_db_hook()
    logging.info('Executing SQL check: ' + self.sql2)
    row2 = hook.get_first(hql=self.sql2)
    logging.info('Executing SQL check: ' + self.sql1)
    row1 = hook.get_first(hql=self.sql1)
    if not row2:
        raise AirflowException(
            "The query {q} returned None".format(q=self.sql2))
    if not row1:
        raise AirflowException(
            "The query {q} returned None".format(q=self.sql1))
    current = dict(zip(self.metrics_sorted, row1))
    reference = dict(zip(self.metrics_sorted, row2))
    ratios = {}
    test_results = {}
    rlog = "Ratio for {0}: {1} \n Ratio threshold : {2}"
    fstr = "'{k}' check failed. {r} is above {tr}"
    estr = "The following tests have failed:\n {0}"
    countstr = "The following {j} tests out of {n} failed:"
    for m in self.metrics_sorted:
        if current[m] == 0 or reference[m] == 0:
            ratio = None
        else:
            ratio = float(max(current[m], reference[m])) / \
                min(current[m], reference[m])
        logging.info(rlog.format(m, ratio, self.metrics_thresholds[m]))
        ratios[m] = ratio
        test_results[m] = ratio < self.metrics_thresholds[m]
    if not all(test_results.values()):
        failed_tests = [it[0] for it in test_results.items() if not it[1]]
        j = len(failed_tests)
        n = len(self.metrics_sorted)
        logging.warning(countstr.format(**locals()))
        for k in failed_tests:
            logging.warning(
                fstr.format(k=k, r=ratios[k],
                            tr=self.metrics_thresholds[k]))
        raise AirflowException(estr.format(", ".join(failed_tests)))
    logging.info("All tests have passed")
def process_subdir(subdir):
    dags_folder = configuration.get("core", "DAGS_FOLDER")
    dags_folder = os.path.expanduser(dags_folder)
    if subdir:
        subdir = os.path.expanduser(subdir)
        if "DAGS_FOLDER" in subdir:
            subdir = subdir.replace("DAGS_FOLDER", dags_folder)
        if dags_folder not in subdir:
            raise AirflowException(
                "subdir has to be part of your DAGS_FOLDER as defined in your "
                "airflow.cfg")
        return subdir
def kill(self):
    session = settings.Session()
    job = session.query(BaseJob).filter(BaseJob.id == self.id).first()
    job.end_date = datetime.now()
    try:
        self.on_kill()
    except:
        logging.error('on_kill() method failed')
    session.merge(job)
    session.commit()
    session.close()
    raise AirflowException("Job shut down externally.")
def _get_sessions(self):
    method = "GET"
    endpoint = "sessions"
    response = self._http_rest_call(method=method, endpoint=endpoint)
    if response.status_code in self.acceptable_response_codes:
        return response.json()["sessions"]
    else:
        raise AirflowException(
            "Call to get sessions didn't return " +
            str(self.acceptable_response_codes) + ". Returned '" +
            str(response.status_code) + "'.")
def parse_gcs_url(gsurl):
    """
    Given a Google Cloud Storage URL (gs://<bucket>/<blob>), returns a
    tuple containing the corresponding bucket and blob.
    """
    parsed_url = urlparse(gsurl)
    if not parsed_url.netloc:
        raise AirflowException('Please provide a bucket name')
    else:
        bucket = parsed_url.netloc
        blob = parsed_url.path.strip('/')
        return (bucket, blob)
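A short sketch of what parse_gcs_url returns; the bucket and object names are illustrative:

# Illustrative only.
bucket, blob = parse_gcs_url('gs://my-bucket/path/to/object.json')
# bucket == 'my-bucket'
# blob == 'path/to/object.json'
# A URL without a bucket, e.g. '/no/bucket/here', raises
# AirflowException('Please provide a bucket name').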
def get_connections(cls, conn_id):
    session = settings.Session()
    db = (
        session.query(Connection)
        .filter(Connection.conn_id == conn_id)
        .all()
    )
    if not db:
        raise AirflowException(
            "The conn_id `{0}` isn't defined".format(conn_id))
    session.expunge_all()
    session.close()
    return db
def test(args):
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    if args.dry_run:
        ti.dry_run()
    else:
        ti.run(force=True, ignore_dependencies=True, test_mode=True)
def set_is_paused(is_paused, args):
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]

    session = settings.Session()
    dm = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).first()
    dm.is_paused = is_paused
    session.commit()

    msg = "Dag: {}, paused: {}".format(dag, str(dag.is_paused))
    print(msg)
def process_subdir(subdir):
    dags_folder = configuration.get("core", "DAGS_FOLDER")
    dags_folder = os.path.expanduser(dags_folder)
    if subdir:
        if "DAGS_FOLDER" in subdir:
            subdir = subdir.replace("DAGS_FOLDER", dags_folder)
        subdir = os.path.abspath(os.path.expanduser(subdir))
        if dags_folder.rstrip('/') not in subdir.rstrip('/'):
            raise AirflowException(
                "subdir has to be part of your DAGS_FOLDER as defined in your "
                "airflow.cfg. DAGS_FOLDER is {df} and subdir is {sd}".format(
                    df=dags_folder, sd=subdir))
    return subdir
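A sketch of how the DAGS_FOLDER placeholder is expanded by the version above; the configured folder value is an assumption for illustration:

# Illustrative only: assumes DAGS_FOLDER is /home/airflow/dags in airflow.cfg.
process_subdir('DAGS_FOLDER/examples')
# -> '/home/airflow/dags/examples'
process_subdir('/tmp/other_dags')
# -> raises AirflowException: subdir has to be part of your DAGS_FOLDER ...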
def execute(self, **kwargs):
    """
    The SlackClient call itself does not raise an error when the API call
    is unsuccessful, so inspect the response and fail the task explicitly;
    otherwise a failed call would let the DAG complete as if it succeeded.
    """
    if not self.api_params:
        self.construct_api_call_params()
    sc = SlackClient(self.token)
    rc = json.loads(
        sc.api_call(self.method, **self.api_params).decode('utf-8'))
    if not rc['ok']:
        logging.error("Slack API call failed ({})".format(rc['error']))
        raise AirflowException("Slack API call failed: ({})".format(
            rc['error']))
def _get_session_statements(self, session_id):
    method = "GET"
    endpoint = "sessions/" + str(session_id) + "/statements"
    response = self._http_rest_call(method=method, endpoint=endpoint)

    if response.status_code in self.acceptable_response_codes:
        response_json = response.json()
        statements = response_json["statements"]
        return statements
    else:
        raise AirflowException(
            "Call to get the session statement response didn't return " +
            str(self.acceptable_response_codes) + ". Returned '" +
            str(response.status_code) + "'.")
def __init__(self, subdag, executor=DEFAULT_EXECUTOR, *args, **kwargs):
    """
    Yo dawg. This runs a sub dag. By convention, a sub dag's dag_id
    should be prefixed by its parent and a dot. As in `parent.child`.

    :param subdag: the DAG object to run as a subdag of the current DAG.
    :type subdag: airflow.DAG
    :param dag: the parent DAG
    :type dag: airflow.DAG
    """
    if 'dag' not in kwargs:
        raise AirflowException("Please pass in the `dag` param")
    dag = kwargs['dag']
    super(SubDagOperator, self).__init__(*args, **kwargs)
    if dag.dag_id + '.' + kwargs['task_id'] != subdag.dag_id:
        raise AirflowException(
            "The subdag's dag_id should have the form "
            "'{{parent_dag_id}}.{{this_task_id}}'. Expected "
            "'{d}.{t}'; received '{rcvd}'.".format(
                d=dag.dag_id, t=kwargs['task_id'], rcvd=subdag.dag_id))
    self.subdag = subdag
    self.executor = executor
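A minimal sketch of the naming convention enforced above; the DAG ids and default_args are placeholders, not part of the original code:

# Illustrative only: the subdag's dag_id must be '<parent_dag_id>.<task_id>'.
parent = DAG('parent_dag', default_args=default_args)
child = DAG('parent_dag.spawn_child', default_args=default_args)
SubDagOperator(task_id='spawn_child', subdag=child, dag=parent)
# A subdag named e.g. 'child_dag' would raise AirflowException here.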
def __init__(self, conn_id, sql, *args, **kwargs):
    super(SqlSensor, self).__init__(*args, **kwargs)
    self.sql = sql
    self.conn_id = conn_id

    session = settings.Session()
    db = session.query(DB).filter(DB.conn_id == conn_id).first()
    if not db:
        raise AirflowException("conn_id doesn't exist in the repository")
    self.hook = db.get_hook()
    session.commit()
    session.close()
def execute(self, context=None):
    logging.info('Executing SQL check: ' + self.sql)
    records = self.get_db_hook().get_first(hql=self.sql)
    if not records:
        raise AirflowException("The query returned None")

    test_results = []
    except_temp = ("Test failed.\nPass value:{self.pass_value}\n"
                   "Query:\n{self.sql}\nResults:\n{records!s}")
    if not self.is_numeric_value_check:
        tests = [str(r) == self.pass_value for r in records]
    elif self.is_numeric_value_check:
        try:
            num_rec = [float(r) for r in records]
        except (ValueError, TypeError) as e:
            cvestr = "Converting a result to float failed.\n"
            raise AirflowException(cvestr + except_temp.format(**locals()))
        if self.has_tolerance:
            tests = [
                r / (1 + self.tol) <= self.pass_value <= r / (1 - self.tol)
                for r in num_rec]
        else:
            tests = [r == self.pass_value for r in num_rec]

    if not all(tests):
        raise AirflowException(except_temp.format(**locals()))
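A worked instance of the tolerance check above, with illustrative numbers:

# With pass_value = 100 and tol = 0.1, a result r satisfies
#     r / (1 + 0.1) <= 100 <= r / (1 - 0.1)
# exactly when 90 <= r <= 110, i.e. r is within +/-10% of the pass value.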
def task_state(args):
    """
    Returns the state of a TaskInstance at the command line.

    >>> airflow task_state tutorial sleep 2015-01-01
    success
    """
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(args.subdir)
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    print(ti.current_state())
def render(args):
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)
    ti.render_templates()
    for attr in task.__class__.template_fields:
        print(textwrap.dedent("""\
        # ----------------------------------------------------------
        # property: {}
        # ----------------------------------------------------------
        {}
        """.format(attr, getattr(task, attr))))
def execute(self, context):
    args = self.cls.parse(self.args)
    self.cmd = self.cls.create(**args)
    context['task_instance'].xcom_push(key='qbol_cmd_id', value=self.cmd.id)
    logging.info("Qubole command created with Id: {0} and Status: {1}".format(
        str(self.cmd.id), self.cmd.status))

    while not Command.is_done(self.cmd.status):
        time.sleep(Qubole.poll_interval)
        self.cmd = self.cls.find(self.cmd.id)
        logging.info("Command Id: {0} and Status: {1}".format(
            str(self.cmd.id), self.cmd.status))

    if self.kwargs.get('fetch_logs'):
        logging.info("Logs for Command Id: {0} \n{1}".format(
            str(self.cmd.id), self.cmd.get_log()))

    if self.cmd.status != 'done':
        raise AirflowException('Command Id: {0} failed with Status: {1}'.format(
            self.cmd.id, self.cmd.status))
def test(args):
    args.execution_date = dateutil.parser.parse(args.execution_date)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]
    task = dag.get_task(task_id=args.task_id)
    # Add CLI provided task_params to task.params
    if args.task_params:
        passed_in_params = json.loads(args.task_params)
        task.params.update(passed_in_params)
    ti = TaskInstance(task, args.execution_date)

    if args.dry_run:
        ti.dry_run()
    else:
        ti.run(force=True, ignore_dependencies=True, test_mode=True)
def __init__(
        self, bucket_name, prefix,
        delimiter='/', s3_conn_id='s3_default',
        *args, **kwargs):
    super(S3PrefixSensor, self).__init__(*args, **kwargs)
    session = settings.Session()
    db = session.query(DB).filter(DB.conn_id == s3_conn_id).first()
    if not db:
        raise AirflowException("conn_id doesn't exist in the repository")
    # Parse
    self.bucket_name = bucket_name
    self.prefix = prefix
    self.delimiter = delimiter
    self.full_url = "s3://" + bucket_name + '/' + prefix
    self.s3_conn_id = s3_conn_id
    session.commit()
    session.close()
def execute(self, context):
    url = '{endpoint}{pageid}/components/{component_id}.json'.format(
        endpoint=self.endpoint, pageid=self.pageid,
        component_id=self.component_id)
    header = {
        'Content-Type': 'application/json',
        'Authorization': 'OAuth {apikey}'.format(apikey=self.apikey)
    }
    body = json.dumps({"component": {"status": self.component_status}})
    response = requests.request("PATCH", url, data=body, headers=header)
    if response.status_code >= 400:
        logging.error('StatusPage API call failed: %s %s',
                      response.status_code, response.reason)
        raise AirflowException('StatusPage API call failed: %s %s' %
                               (response.status_code, response.reason))
def execute(self, context):
    url = ('{endpoint}/rest/api/latest/issue/{ticket_id}/transitions'
           '?expand=transitions.fields').format(
        endpoint=self.endpoint, ticket_id=self.ticket_id)
    header = {
        'Content-Type': 'application/json',
        'Authorization': 'Basic {creds}'.format(
            creds=base64.b64encode(self.api_user + ":" + self.api_password))
    }
    body = json.dumps({"transition": {"id": self.transition_id}})
    response = requests.request("POST", url, data=body, headers=header)
    if response.status_code >= 400:
        logging.error('JIRA API call failed: %s %s',
                      response.status_code, response.reason)
        raise AirflowException('JIRA API call failed: %s %s' %
                               (response.status_code, response.reason))
def _submit_spark_script(self, session_id):
    method = "POST"
    endpoint = "sessions/" + str(session_id) + "/statements"

    logging.info("Executing Spark Script: \n" + str(self.spark_script))

    data = {'code': textwrap.dedent(self.spark_script)}
    response = self._http_rest_call(method=method, endpoint=endpoint, data=data)

    if response.status_code in self.acceptable_response_codes:
        response_json = response.json()
        return response_json["id"], response_json["state"]
    else:
        raise AirflowException(
            "Call to create a new statement didn't return " +
            str(self.acceptable_response_codes) + ". Returned '" +
            str(response.status_code) + "'.")
def run_cli(self, pig, verbose=True):
    """
    Run a pig script using the pig cli

    >>> ph = PigCliHook()
    >>> result = ph.run_cli("ls /;")
    >>> ("hdfs://" in result)
    True
    """
    with TemporaryDirectory(prefix='airflow_pigop_') as tmp_dir:
        with NamedTemporaryFile(dir=tmp_dir) as f:
            f.write(pig)
            f.flush()
            fname = f.name
            pig_bin = 'pig'
            cmd_extra = []

            pig_cmd = [pig_bin, '-f', fname] + cmd_extra

            if self.pig_properties:
                pig_properties_list = self.pig_properties.split()
                pig_cmd.extend(pig_properties_list)

            if verbose:
                logging.info(" ".join(pig_cmd))
            sp = subprocess.Popen(
                pig_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=tmp_dir)
            self.sp = sp
            stdout = ''
            for line in iter(sp.stdout.readline, ''):
                stdout += line
                if verbose:
                    logging.info(line.strip())
            sp.wait()

            if sp.returncode:
                raise AirflowException(stdout)

            return stdout
def max_partition(self, schema, table_name, field=None, filter=None):
    '''
    Returns the maximum value for all partitions in a table. Works only
    for tables that have a single partition key. For subpartitioned
    tables, we recommend using signal tables.

    >>> hh = HiveMetastoreHook()
    >>> t = 'static_babynames_partitioned'
    >>> hh.max_partition(schema='airflow', table_name=t)
    '2015-01-01'
    '''
    parts = self.get_partitions(schema, table_name, filter)
    if not parts:
        return None
    elif len(parts[0]) == 1:
        field = list(parts[0].keys())[0]
    elif not field:
        raise AirflowException(
            "Please specify the field you want the max value for")
    return max([p[field] for p in parts])
def backfill(args):
    logging.basicConfig(
        level=settings.LOGGING_LEVEL,
        format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, use it as both start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(
            task_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            include_adhoc=args.include_adhoc,
            local=args.local,
            donot_pickle=(args.donot_pickle or
                          configuration.getboolean('core', 'donot_pickle')),
            ignore_dependencies=args.ignore_dependencies,
            pool=args.pool)
def load_login():
    auth_backend = 'airflow.default_login'
    try:
        if conf.getboolean('webserver', 'AUTHENTICATE'):
            auth_backend = conf.get('webserver', 'auth_backend')
    except conf.AirflowConfigException:
        if conf.getboolean('webserver', 'AUTHENTICATE'):
            logging.warning(
                "auth_backend not found in webserver config reverting to "
                "*deprecated* behavior of importing airflow_login")
            auth_backend = "airflow_login"

    try:
        global login
        login = import_module(auth_backend)
    except ImportError as err:
        logging.critical(
            "Cannot import authentication module %s. "
            "Please correct your authentication backend or disable "
            "authentication: %s",
            auth_backend, err)
        if conf.getboolean('webserver', 'AUTHENTICATE'):
            raise AirflowException("Failed to import authentication backend")
def execute(self, context):
    """
    Execute the bash command in a temporary directory
    which will be cleaned afterwards
    """
    bash_command = self.bash_command
    logging.info("tmp dir root location: \n" + gettempdir())
    with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
        with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f:

            f.write(bytes(bash_command, 'utf_8'))
            f.flush()
            fname = f.name
            # NamedTemporaryFile already yields the full path of the script
            script_location = fname
            logging.info("Temporary script "
                         "location :{0}".format(script_location))
            logging.info("Running command: " + bash_command)
            sp = Popen(
                ['bash', fname],
                stdout=PIPE, stderr=STDOUT,
                cwd=tmp_dir, env=self.env)

            self.sp = sp

            logging.info("Output:")
            line = ''
            for line in iter(sp.stdout.readline, b''):
                line = line.decode(self.output_encoding).strip()
                logging.info(line)
            sp.wait()
            logging.info("Command exited with "
                         "return code {0}".format(sp.returncode))

            if sp.returncode:
                raise AirflowException("Bash command failed")

    if self.xcom_push_flag:
        return line
def run_cli(self, hql, schema=None):
    '''
    Run an hql statement using the hive cli

    >>> hh = HiveCliHook()
    >>> result = hh.run_cli("USE airflow;")
    >>> ("OK" in result)
    True
    '''
    if schema:
        hql = "USE {schema};\n{hql}".format(**locals())

    with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir:
        with NamedTemporaryFile(dir=tmp_dir) as f:
            f.write(hql)
            f.flush()
            fname = f.name
            hive_cmd = ['hive', '-f', fname]
            if self.hive_cli_params:
                hive_params_list = self.hive_cli_params.split()
                hive_cmd.extend(hive_params_list)
            sp = subprocess.Popen(
                hive_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                cwd=tmp_dir)
            self.sp = sp
            stdout = ""
            for line in iter(sp.stdout.readline, ''):
                stdout += line
                logging.info(line.strip())
            sp.wait()

            if sp.returncode:
                # stderr is merged into stdout above, so raise the captured
                # output rather than an empty error buffer
                raise AirflowException(stdout)

            return stdout