def execute(self, context):
    # Specifying a service account file allows the user to use non-default
    # authentication for creating a Kubernetes Pod. This is done by setting the
    # environment variable `GOOGLE_APPLICATION_CREDENTIALS` that gcloud looks at.
    key_file = None

    # If gcp_conn_id is not specified gcloud will use the default
    # service account credentials.
    if self.gcp_conn_id:
        from airflow.hooks.base_hook import BaseHook
        # extras is a deserialized json object
        extras = BaseHook.get_connection(self.gcp_conn_id).extra_dejson
        # key_file only gets set if a json file is created from a JSON string in
        # the web ui, else none
        key_file = self._set_env_from_extras(extras=extras)

    # Write config to a temp file and set the environment variable to point to it.
    # This is to avoid race conditions of reading/writing a single file
    with tempfile.NamedTemporaryFile() as conf_file:
        os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
        # Attempt to get/update credentials
        # We call gcloud directly instead of using google-cloud-python api
        # because there is no way to write kubernetes config to a file, which is
        # required by KubernetesPodOperator.
        # The gcloud command looks at the env variable `KUBECONFIG` for where to save
        # the kubernetes config file.
        subprocess.check_call(
            ["gcloud", "container", "clusters", "get-credentials",
             self.cluster_name,
             "--zone", self.location,
             "--project", self.project_id])

        # Since the key file is of type mkstemp() closing the file will delete it
        # from the file system, so it cannot be accessed after we no longer need it
        if key_file:
            key_file.close()

        # Tell `KubernetesPodOperator` where the config file is located
        self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
        return super().execute(context)
def execute(self, context):
    # If gcp_conn_id is not specified gcloud will use the default
    # service account credentials.
    if self.gcp_conn_id:
        from airflow.hooks.base_hook import BaseHook
        # extras is a deserialized json object
        extras = BaseHook.get_connection(self.gcp_conn_id).extra_dejson
        self._set_env_from_extras(extras=extras)

    # Write config to a temp file and set the environment variable to point to it.
    # This is to avoid race conditions of reading/writing a single file
    with tempfile.NamedTemporaryFile() as conf_file:
        os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name
        # Attempt to get/update credentials
        # We call gcloud directly instead of using google-cloud-python api
        # because there is no way to write kubernetes config to a file, which is
        # required by KubernetesPodOperator.
        # The gcloud command looks at the env variable `KUBECONFIG` for where to save
        # the kubernetes config file.
        subprocess.check_call(
            ["gcloud", "container", "clusters", "get-credentials",
             self.cluster_name,
             "--zone", self.location,
             "--project", self.project_id])

        # Tell `KubernetesPodOperator` where the config file is located
        self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
        super(GKEPodOperator, self).execute(context)
def poke(self, context):
    hook = BaseHook.get_connection(self.conn_id).get_hook()
    self.log.info('Poking: %s', self.sql)
    records = hook.get_records(self.sql)
    if not records:
        return False
    return str(records[0][0]) not in ('0', '')
def _get_project_id():
    """Get project ID from default GCP connection."""
    extras = BaseHook.get_connection('google_cloud_default').extra_dejson
    key = 'extra__google_cloud_platform__project'
    if key in extras:
        project_id = extras[key]
    else:
        # Raising a bare string is a TypeError in Python 3; raise a proper
        # exception instead.
        raise ValueError('Must configure project_id in google_cloud_default '
                         'connection from Airflow Console')
    return project_id
def poke(self, context):
    hook = BaseHook.get_connection(self.conn_id).get_hook()
    logging.info('Poking: %s', self.sql)
    records = hook.get_records(self.sql)
    if not records:
        return False
    # The original ended with an unreachable print() after the returns;
    # collapsing the nested if/else into a single expression removes it.
    return str(records[0][0]) not in ('0', '')
def pre_execute(self, context):
    dbnd_conn_config = BaseHook.get_connection(DATABAND_AIRFLOW_CONN_ID)
    json_config = dbnd_conn_config.extra_dejson
    dbnd_config = self.to_env(
        self.flatten(json_config, parent_key="DBND", sep="__"))

    self.env = os.environ.copy()
    self.env.update(dbnd_config)
    self.env.update({
        "DBND__LOG__LEVEL": LOG_LEVEL,
        "DBND__AIRFLOW_MONITOR__SQL_ALCHEMY_CONN": settings.SQL_ALCHEMY_CONN,
        "DBND__AIRFLOW_MONITOR__LOCAL_DAG_FOLDER": settings.DAGS_FOLDER,
        "DBND__AIRFLOW_MONITOR__FETCHER": "db",
    })
def __init__(self, engine: Engine):
    # Bind the engine to the metadata only if it is not bound already
    if self.metadata.bind is None:
        # Skip engine creation if one is already provided
        if engine is None:
            connection = BaseHook.get_connection(conn_id=settings.DB_CONN_ID)
            engine = db_utils.create_db_engine(
                login=connection.login,
                password=connection.password,
                host=connection.host,
                schema=connection.schema,
                conn_type=connection.conn_type)
        self.metadata.bind = engine
def read_mock_data():
    print("Fetching mock response text...")
    connection = BaseHook.get_connection("MOCKSERVER")
    conn_type = connection.conn_type
    host = connection.host
    port = connection.port
    mockserver_conn = (str(conn_type) + "://" + str(host) + ":" + str(port)
                       + "/" + "mock-response")
    print("connecting to " + mockserver_conn)
    r = requests.get(mockserver_conn)
    string_data = StringIO(r.text)
    df = pd.read_csv(string_data, sep=",", parse_dates=["timestamp"])
    df["bikes_available"] = df["bikes_available"].astype("int32")
    df["docks_available"] = df["docks_available"].astype("int32")
    print(df)
    return df
def poke(self, context):
    conn = BaseHook.get_connection(self.conn_id)

    allowed_conn_type = {'google_cloud_platform', 'jdbc', 'mssql',
                         'mysql', 'oracle', 'postgres', 'presto',
                         'sqlite', 'vertica'}
    if conn.conn_type not in allowed_conn_type:
        raise AirflowException(
            "The connection type is not supported by SqlSensor. " +
            "Supported connection types: {}".format(list(allowed_conn_type)))

    hook = conn.get_hook()

    self.log.info('Poking: %s (with parameters %s)', self.sql, self.parameters)
    records = hook.get_records(self.sql, self.parameters)
    if not records:
        return False
    return str(records[0][0]) not in ('0', '')
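# A minimal usage sketch (not from the source) of wiring this kind of SQL
# sensor into a DAG; the connection id, SQL, and DAG below are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.sensors.sql_sensor import SqlSensor

dag = DAG(dag_id='example_sql_sensor',
          start_date=datetime(2021, 1, 1),
          schedule_interval='@daily')

wait_for_rows = SqlSensor(
    task_id='wait_for_rows',
    conn_id='my_postgres',  # assumed connection id of an allowed type
    sql="SELECT COUNT(1) FROM staging.events WHERE ds = '{{ ds }}'",
    poke_interval=60,       # seconds between pokes
    timeout=60 * 60,        # fail the task after an hour of poking
    dag=dag,
)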
def tell_slack_failed(context):
    # return SlackNequiOperator.tell_slack_failed(context)
    webhook = BaseHook.get_connection('Slack').password
    message = ':red_circle: AIRFLOW TASK FAILURE TIPS:\n' \
              'DAG: {}\n' \
              'TASKS: {}\n' \
              'Reason: {}\n' \
        .format(context['task_instance'].dag_id,
                context['task_instance'].task_id,
                context['exception'])
    alter_hook = SlackWebhookOperator(
        task_id='integrate_slack',
        http_conn_id='Slack',
        webhook_token=webhook,
        message=message,
        username='******')
    return alter_hook.execute(context=context)
def poke(self, context):
    conn = BaseHook.get_connection(self.qubole_conn_id)
    Qubole.configure(api_token=conn.password, api_url=conn.host)
    # The original referenced an undefined `this`; `self` is what was meant.
    self.log.info('Poking: %s', self.data)
    status = False
    try:
        status = self.sensor_class.check(self.data)
    except Exception as e:
        self.log.exception(e)
        status = False
    self.log.info('Status of this Poke: %s', status)
    return status
def _get_hook(self):
    self.log.debug("Get connection for %s", self.conn_id)
    conn = BaseHook.get_connection(self.conn_id)

    if conn.conn_type not in ALLOWED_CONN_TYPE:
        # The original used a backslash continuation inside the string, which
        # embedded the indentation into the message; implicit concatenation
        # keeps the message clean.
        raise AirflowException(
            "The connection type is not supported by BranchSQLOperator. "
            "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE)))

    if not self._hook:
        self._hook = conn.get_hook()
        if self.database:
            self._hook.schema = self.database

    return self._hook
def execute(self, context):
    self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
    hook = CloudSQLDatabaseHook(
        gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
        gcp_conn_id=self.gcp_conn_id,
        default_gcp_project_id=self.gcp_connection.extra_dejson.get(
            'extra__google_cloud_platform__project'),
    )
    hook.validate_ssl_certs()
    connection = hook.create_connection()
    hook.validate_socket_path_length()
    database_hook = hook.get_database_hook(connection=connection)
    try:
        self._execute_query(hook, database_hook)
    finally:
        hook.cleanup_database_hook()
def update_la_cases_data(**kwargs):
    """
    The actual python callable that Airflow schedules.
    """
    # Getting data from google sheet
    filename = kwargs.get("filename")
    workbook = kwargs.get("workbook")
    sheet_name = kwargs.get("sheet_name")
    get_data(filename, workbook, sheet_name)

    # Updating ArcGIS
    arcconnection = BaseHook.get_connection("arcgis")
    arcuser = arcconnection.login
    arcpassword = arcconnection.password
    arcfeatureid = kwargs.get("arcfeatureid")
    update_arcgis(arcuser, arcpassword, arcfeatureid, filename)
def _get_auth_headers(self):
    """
    Helper function that builds the headers for API requests,
    including the authentication header.
    """
    conn_values = BaseHook.get_connection(self.conn_id)
    message = f"{conn_values.login}:{conn_values.password}"
    message_bytes = message.encode('ascii')
    base64_bytes = base64.b64encode(message_bytes)
    base64_message = base64_bytes.decode('ascii')
    headers = {
        "user-agent": "airflow-SEGES-ME",
        "authorization": f"Basic {base64_message}"
    }
    return headers
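# Design note (a sketch, not from the source): requests can build the same
# Basic auth header itself, so the manual base64 encoding above is equivalent
# to passing a (login, password) tuple; the endpoint and credentials below
# are placeholders.
import requests

response = requests.get(
    "https://api.example.com/v1/resource",
    auth=("my_login", "my_password"),
    headers={"user-agent": "airflow-SEGES-ME"},
)
response.raise_for_status()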
def dag_notify_success_slack_alert(dag):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    slack_msg = """
        :white_check_mark: DAG Completed.
        *Dag*: {dag}
        """.format(dag=dag.dag_id)
    success_alert = SlackWebhookOperator(
        task_id="notify_dag_success",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username="******",
        trigger_rule="all_success",
        dag=dag,
    )
    return success_alert
def execute(self, context):
    hook = BaseHook.get_connection(self.conn_id).get_hook()
    # We don't create the table here. We could, but if the signal
    # table is missing there's probably something more disastrously
    # wrong that should be looked into.
    #
    # For the record, it was created with:
    #   CREATE SCHEMA airflow;
    #   CREATE TABLE airflow.signal (
    #       schema_name TEXT, table_name TEXT,
    #       partition_id TEXT, status VARCHAR(256));
    sql = """
    INSERT INTO airflow.signal (schema_name, table_name, partition_id, status)
    VALUES ('{0}', '{1}', '{2}', '{3}')
    """.format(self.schema, self.table, self.partition_id, 'done')
    hook.run(sql)
def task_fail_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    # print this task_msg and tag these users
    task_msg = """The Task {task} failed. {slack_name} please fix it """.format(
        task=context.get('task_instance').task_id,
        slack_name=list_names)
    # this adds the error log url at the end of the msg
    slack_msg = task_msg + """ (<{log_url}|log>)""".format(
        log_url=context.get('task_instance').log_url)
    failed_alert = SlackWebhookOperator(
        task_id='slack_test',
        http_conn_id='slack',
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username='******',
    )
    return failed_alert.execute(context=context)
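# A minimal sketch (assumed names) of how a callback like task_fail_slack_alert
# is typically attached to every task in a DAG via default_args.
from datetime import datetime

from airflow import DAG

default_args = {
    'owner': 'airflow',
    # Airflow calls this with the task context whenever a task fails
    'on_failure_callback': task_fail_slack_alert,
}

dag = DAG(dag_id='example_with_alerts',  # placeholder dag id
          default_args=default_args,
          start_date=datetime(2021, 1, 1),
          schedule_interval='@daily')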
def poke(self, context):
    self.log.info('Poking for file oss://%s/%s', self.bucket_name, self.bucket_key)
    conn = BaseHook.get_connection(self.oss_conn_id)
    endpoint = conn.host
    access_key_id = conn.login
    access_key_secret = conn.password
    auth = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(auth, endpoint, self.bucket_name)
    try:
        objects = bucket.list_objects(self.bucket_key)
        return bool(len(objects.object_list))
    except Exception as e:
        self.log.debug("Caught an exception: %s", str(e))
        return False
def update_cluster_connection(**kwargs):
    """
    Updates the ssh_default connection of the EMR cluster after creating a new
    one using setup_cloud_environment
    :param kwargs:
    :return:
    """
    ti = kwargs['ti']
    emr_dns = ti.xcom_pull(task_ids='iac_create_emr_cluster')
    ssh_conn = BaseHook.get_connection('ssh_default')
    ssh_conn.host = emr_dns
    session = settings.Session()  # get the session
    session.add(ssh_conn)
    session.commit()
def task_fail_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    task_msg = ('The {task} in Refreshing the WYS Open Data failed, '
                '{slack_name} go fix it meow :meow_headache: ').format(
        task=context.get('task_instance').task_id,
        slack_name=list_names)
    slack_msg = task_msg + '(<{log_url}|log>)'.format(
        log_url=context.get('task_instance').log_url)
    failed_alert = SlackWebhookOperator(
        task_id='slack_test',
        http_conn_id='slack',
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username='******',
    )
    return failed_alert.execute(context=context)
def get_link(self, operator, dttm):
    """
    Get link to qubole command result page.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    ti = TaskInstance(task=operator, execution_date=dttm)
    conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
    if conn and conn.host:
        host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
    else:
        host = 'https://api.qubole.com/v2/analyze?command_id='
    qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
    url = host + str(qds_command_id) if qds_command_id else ''
    return url
def extract_weather_forecast(lon, lat, **kwargs):
    logging.info("Extracting Weather Forecast...")
    URL = "https://api.darksky.net/forecast"
    try:
        api_key = BaseHook.get_connection("dark_sky").password
        url = f"{URL}/{api_key}/{lon},{lat}"
        r = requests.get(url=url)
        r.raise_for_status()
        data = r.json()
        kwargs["ti"].xcom_push(key="raw_data", value=data)
    except BaseException as e:
        logging.error("Failed to extract forecast from API!")
        raise e
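# A hypothetical downstream callable pulling the pushed forecast back out of
# XCom; the upstream task id and the Dark Sky 'currently' block are assumptions.
def transform_weather(**kwargs):
    data = kwargs["ti"].xcom_pull(task_ids="extract_weather_forecast",
                                  key="raw_data")
    # Dark Sky responses carry current conditions under 'currently'
    return data["currently"]["temperature"]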
def poke(self, context):
    hook = BaseHook.get_connection(self.conn_id).get_hook()
    sql = """
    SELECT COUNT(1) FROM airflow.signal
     WHERE schema_name = '{0}'
       AND table_name = '{1}'
       AND partition_id = '{2}'
       AND status = 'done'
    """.format(self.schema, self.table, self.partition_id)
    logging.info('Poking: %s', sql)
    record = hook.get_first(sql)
    signal_present = record[0] > 0
    if not signal_present:
        raise AirflowException(
            'Not present -- retry. If this is a test, then run the '
            'dependent job to fix the signal table issue.')
    return True
def poke(self, context):
    global this  # apache/incubator-airflow/pull/3297#issuecomment-385988083
    conn = BaseHook.get_connection(self.qubole_conn_id)
    Qubole.configure(api_token=conn.password, api_url=conn.host)
    this.log.info('Poking: %s', self.data)
    status = False
    try:
        status = self.sensor_class.check(self.data)
    except Exception as e:
        this.log.exception(e)
        status = False
    this.log.info('Status of this Poke: %s', status)
    return status
def run_query(replication_key_value=None, **kwargs):
    """
    Run a query against Stripe.

    :param replication_key_value: Stripe replication key value
    """
    stripe.api_key = BaseHook.get_connection("stripe").password
    stripe_endpoint = getattr(stripe, 'Event')
    params = {}
    params['created[gte]'] = replication_key_value
    stripe_response = getattr(stripe_endpoint, 'list')(**params)
    for res in stripe_response.auto_paging_iter():
        json_result = json.loads(json.dumps(res))
        yield json_result
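# A hypothetical driver for the generator above: the epoch timestamp is a
# placeholder, and printing stands in for whatever load step follows.
for event in run_query(replication_key_value=1609459200):  # 2021-01-01 UTC
    print(event["id"], event["type"])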
def read_hdfs_file():
    connection = BaseHook.get_connection("webhdfs_default")
    conn_type = connection.conn_type
    host = connection.host
    port = connection.port
    hdfs_conn = str(conn_type) + "://" + str(host) + ":" + str(port)
    print("connecting to " + hdfs_conn)
    myclient = client.Client(hdfs_conn)
    pathname = os.path.join(tempfile.mkdtemp(), 'hdfs_stationMart_data')
    print("downloading files from HDFS to " + pathname)
    myclient.download("/tw/stationMart/data/", pathname, True)
    print("downloading complete")
    dataset = pq.ParquetDataset(pathname)
    table = dataset.read()
    df = table.to_pandas()
    print(df)
    return df
def poke(self, context):
    conn = BaseHook.get_connection(self.qubole_conn_id)
    Qubole.configure(api_token=conn.password, api_url=conn.host)
    self.log.info('Poking: %s', self.data)
    status = False
    try:
        status = self.sensor_class.check(self.data)  # pylint: disable=no-member
    except Exception as e:  # pylint: disable=broad-except
        self.log.exception(e)
        status = False
    self.log.info('Status of this Poke: %s', status)
    return status
def execute(self, context):
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Clearing data from destination Redshift table")
    redshift.run("DELETE FROM {}".format(self.table))

    self.log.info("Copying data from S3 to Redshift")
    s3_path = "s3://{}/{}".format(self.s3_bucket, self.rendered_key)
    self.log.info(s3_path)
    credentials = BaseHook.get_connection(self.aws_credentials_id)
    formatted_sql = StageToRedshiftOperator.copy_sql.format(
        self.table,
        s3_path,
        self.json,
        credentials.login,
        credentials.password)
    self.log.info(formatted_sql)
    redshift.run(formatted_sql)
def get_extra_from_conn(conn_id):
    """
    Obtain extra fields from airflow connection.

    Parameters
    ----------
    conn_id : str
        Airflow Connection ID

    Returns
    -------
    dict
        extra kwargs
    """
    # get_connection is a classmethod, so no BaseHook instance is needed
    conn = BaseHook.get_connection(conn_id)
    return json.loads(conn.extra)
def _get_data_from_sql(self, sql, params=None, return_dict=True):
    '''In Oracle, params are passed to the execute() function as kwargs
    https://cx-oracle.readthedocs.io/en/latest/user_guide/bind.html

    Params are then referenced like so:
        SELECT * FROM table WHERE field1 = :val1 AND field2 = :val2
    with a params dict structured like so:
        {'val1': 1, 'val2': 3}

    Also see here:
    https://stackoverflow.com/questions/35045879/cx-oracle-how-can-i-receive-each-row-as-a-dictionary
    to understand the workaround regarding return_dict
    '''
    def make_dict_factory(cursor):
        column_names = [d[0] for d in cursor.description]

        def create_row(*args):
            return dict(zip(column_names, args))
        return create_row

    params = params or {}  # avoid a mutable default argument
    connection = BaseHook.get_connection(self.source_conn_id)
    oracle_conn = cx_Oracle.connect(
        connection.login,
        connection.password,
        '{0}:{1}/{2}'.format(connection.host,
                             connection.port,
                             connection.schema),
        encoding='UTF-8')
    cursor = oracle_conn.cursor()
    if sql.strip()[-1:] == ';':  # Oracle SQL doesn't like semicolons
        sql = sql.strip()[:-1]
    self.log.info('Executing:\n{0}\n\nWith params:\n{1}'.format(sql, str(params)))
    cursor.execute(sql, **params)
    if return_dict:
        cursor.rowfactory = make_dict_factory(cursor)
    data = cursor.fetchall()
    cursor.close()
    oracle_conn.close()
    return data
def get_extra_links(self, operator, dttm):
    """
    Get link to qubole command result page.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
    if conn and conn.host:
        host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
    else:
        host = 'https://api.qubole.com/v2/analyze?command_id='
    ti = TaskInstance(task=operator, execution_date=dttm)
    qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
    url = host + str(qds_command_id) if qds_command_id else ''
    return url
def __init__(self, datadog_conn_id='datadog_default'):
    super().__init__()
    conn = BaseHook.get_connection(datadog_conn_id)
    self.api_key = conn.extra_dejson.get('api_key', None)
    self.app_key = conn.extra_dejson.get('app_key', None)
    # The original looked up 'source_type_name ' with a trailing space,
    # which could never match the key as entered in the connection extras.
    self.source_type_name = conn.extra_dejson.get('source_type_name', None)

    # If the host is populated, it will use that hostname instead
    # for all metric submissions
    self.host = conn.host

    if self.api_key is None:
        raise AirflowException('api_key must be specified in the '
                               'Datadog connection details')

    self.log.info('Setting up api keys for Datadog')
    self.stats = None
    initialize(api_key=self.api_key, app_key=self.app_key)
def get_mssql_odbc_conn_str(conn_id: str):
    """
    Builds a SQL Server connection string using the pyodbc driver.
    """
    conn_values = BaseHook.get_connection(conn_id)
    driver = '{ODBC Driver 17 for SQL Server}'
    server = conn_values.host
    port = conn_values.port
    database = conn_values.schema
    user = conn_values.login
    password = conn_values.password
    # Implicit string concatenation avoids the stray whitespace that the
    # original backslash continuation embedded in the connection string.
    mssql_conn = (
        f"Driver={driver};Server={server},{port};"
        f"Database={database};Uid={user};Pwd={password};"
    )
    quoted_conn_str = urllib.parse.quote_plus(mssql_conn)
    return f'mssql+pyodbc:///?odbc_connect={quoted_conn_str}'
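# A sketch (assumed connection id and table) of consuming the helper above:
# SQLAlchemy and pandas both accept the mssql+pyodbc URI it returns.
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine(get_mssql_odbc_conn_str('mssql_default'))
df = pd.read_sql('SELECT TOP 10 * FROM my_table', engine)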
def get_redshift_uri(conn_id):
    """
    Builds a (redshift-) jdbc-uri from a given airflow connection-id.
    """
    conn = BaseHook.get_connection(conn_id)
    if not conn.host:
        return ""
    uri = (f"jdbc:redshift://{conn.host}:{conn.port}/{conn.schema}"
           f"?user={conn.login}&password={conn.password}")
    extra = conn.extra_dejson
    params = [f"{k}={v}" for k, v in extra.items()]
    if params:
        # The URI already carries a query string, so extras must be appended
        # with '&' rather than starting a second '?'.
        uri += "&" + "&".join(params)
    return uri
def extract_from_app_store(ds, **kwargs):
    # store_code is needed in both branches and by the client below, so it
    # must be read before the test_mode check (the original only set it in
    # the else branch, raising NameError in test mode).
    store_code = kwargs.get('store_id')
    if kwargs.get('test_mode'):
        # connection = Connection(host='docker.for.mac.localhost', port=9092)
        connection = Connection(host='0.0.0.0', port=9092)
    else:
        connection = BaseHook.get_connection(
            "app_store_{store_code}".format(store_code=store_code))
    app_store_client = AppStoreXClient(host=connection.host,
                                       port=connection.port,
                                       store_code=store_code)
    stats = app_store_client.get_stats(date_from=ds, date_to=ds)
    # need_put_data_into_external_storage
    return stats
def __init__(self,
             sql,
             autocommit=False,
             parameters=None,
             gcp_conn_id='google_cloud_default',
             gcp_cloudsql_conn_id='google_cloud_sql_default',
             *args, **kwargs):
    super(CloudSqlQueryOperator, self).__init__(*args, **kwargs)
    self.sql = sql
    self.gcp_conn_id = gcp_conn_id
    self.gcp_cloudsql_conn_id = gcp_cloudsql_conn_id
    self.autocommit = autocommit
    self.parameters = parameters
    self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
    self.cloudsql_db_hook = CloudSqlDatabaseHook(
        gcp_cloudsql_conn_id=gcp_cloudsql_conn_id,
        default_gcp_project_id=self.gcp_connection.extra_dejson.get(
            'extra__google_cloud_platform__project'))
    self.cloud_sql_proxy_runner = None
    self.database_hook = None