def test_dbapi_get_uri(self):
    conn = BaseHook.get_connection(conn_id='test_uri')
    hook = conn.get_hook()
    assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == hook.get_uri()
    conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
    hook2 = conn2.get_hook()
    assert 'postgres://ec2.compute.com/the_database' == hook2.get_uri()

def task_fail_slack_alert(context):
    """
    Callback that can be used in a DAG to alert on task failure.

    Args:
        context (dict): Context variable passed in from Airflow.

    Returns:
        None: Calls the SlackWebhookOperator execute method internally.
    """
    connection = BaseHook.get_connection(SLACK_CONN_ID)
    slack_webhook_token = connection.password
    slack_msg = (
        "*Status*: :x: Task Failed\n"
        "*Task*: {task}\n"
        "*Dag*: {dag}\n"
        "*Execution Time*: {exec_date}\n"
        "*Log Url*: {log_url}"
    ).format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    if connection.extra_dejson.get('users'):
        slack_msg = slack_msg + '\n' + connection.extra_dejson.get('users')
    failed_alert = SlackWebhookOperator(
        task_id="slack_task",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username="******",
        link_names=True,
    )
    return failed_alert.execute(context=context)

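# Usage sketch (assumption, not from the snippet above): callbacks like
# task_fail_slack_alert are typically wired into a DAG via default_args, so
# Airflow invokes them whenever a task instance fails. The DAG id, start
# date, and schedule below are hypothetical.
from datetime import datetime

from airflow import DAG

default_args = {
    "on_failure_callback": task_fail_slack_alert,
}

dag = DAG(
    dag_id="example_dag",
    default_args=default_args,
    start_date=datetime(2021, 1, 1),
    schedule_interval="@daily",
)
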
def get_minio_object(pandas_read_callable, bucket, paths, pandas_read_callable_kwargs=None):
    s3_conn = BaseHook.get_connection(conn_id="s3")
    minio_client = Minio(
        s3_conn.extra_dejson["host"].split("://")[1],
        access_key=s3_conn.extra_dejson["aws_access_key_id"],
        secret_key=s3_conn.extra_dejson["aws_secret_access_key"],
        secure=False,
    )
    if isinstance(paths, str):
        if paths.startswith("[") and paths.endswith("]"):
            # literal_eval parses the list literal without eval()'s
            # arbitrary-code-execution risk.
            from ast import literal_eval
            paths = literal_eval(paths)
        else:
            paths = [paths]
    if pandas_read_callable_kwargs is None:
        pandas_read_callable_kwargs = {}
    dfs = []
    for path in paths:
        minio_object = minio_client.get_object(bucket_name=bucket, object_name=path)
        df = pandas_read_callable(minio_object, **pandas_read_callable_kwargs)
        dfs.append(df)
    return pd.concat(dfs)

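# Usage sketch (assumption): any pandas reader can be passed as the callable,
# since Minio's get_object() returns a file-like response. The bucket name
# and object keys below are hypothetical.
df = get_minio_object(
    pandas_read_callable=pd.read_csv,
    bucket="datalake",
    paths=["raw/a.csv", "raw/b.csv"],
    pandas_read_callable_kwargs={"sep": ","},
)
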
def slack_success_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password
    msg = """
        :green_circle: Task Successful.
        *Task*: {task}
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    success_alert = SlackWebhookOperator(
        task_id="slack_success_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )
    return success_alert.execute(context=context)

def slack_dag_failure_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password
    icon_color = (
        ":red_circle"
        if configuration.ENVIRONMENT.lower() == "production"
        else ":yellow_circle"
    )
    msg = """
        {icon_color}: Task Failed.
        *Task*: {task}
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        icon_color=icon_color,
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    failed_alert = SlackWebhookOperator(
        task_id="slack_failed_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )
    return failed_alert.execute(context=context)

def _post_sendgrid_mail(mail_data: Dict, conn_id: str = "sendgrid_default") -> None:
    api_key = None
    try:
        conn = BaseHook.get_connection(conn_id)
        api_key = conn.password
    except AirflowException:
        pass
    if api_key is None:
        warnings.warn(
            "Fetching Sendgrid credentials from environment variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        api_key = os.environ.get('SENDGRID_API_KEY')
    sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key)
    response = sendgrid_client.client.mail.send.post(request_body=mail_data)
    # 2xx status code.
    if 200 <= response.status_code < 300:
        log.info(
            'Email with subject %s is successfully sent to recipients: %s',
            mail_data['subject'],
            mail_data['personalizations'],
        )
    else:
        log.error(
            'Failed to send out email with subject %s, status code: %s',
            mail_data['subject'],
            response.status_code,
        )

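# Usage sketch (assumption): a minimal payload following the shape of the
# Sendgrid v3 Mail Send API (personalizations/from/subject/content); the
# addresses are placeholders.
mail_data = {
    "personalizations": [{"to": [{"email": "ops@example.com"}]}],
    "from": {"email": "airflow@example.com"},
    "subject": "Job finished",
    "content": [{"type": "text/plain", "value": "All tasks succeeded."}],
}
_post_sendgrid_mail(mail_data)
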
def send_mime_email(
    e_from: str,
    e_to: Union[str, List[str]],
    mime_msg: MIMEMultipart,
    conn_id: str = "smtp_default",
    dryrun: bool = False,
) -> None:
    """Send MIME email."""
    smtp_host = conf.get('smtp', 'SMTP_HOST')
    smtp_port = conf.getint('smtp', 'SMTP_PORT')
    smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
    smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
    smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
    smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')
    smtp_user = None
    smtp_password = None

    if conn_id is not None:
        try:
            from airflow.hooks.base import BaseHook

            airflow_conn = BaseHook.get_connection(conn_id)
            smtp_user = airflow_conn.login
            smtp_password = airflow_conn.password
        except AirflowException:
            pass
    if smtp_user is None or smtp_password is None:
        warnings.warn(
            "Fetching SMTP credentials from configuration variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        try:
            smtp_user = conf.get('smtp', 'SMTP_USER')
            smtp_password = conf.get('smtp', 'SMTP_PASSWORD')
        except AirflowConfigException:
            log.debug("No user/password found for SMTP, so logging in with no authentication.")

    if not dryrun:
        for attempt in range(1, smtp_retry_limit + 1):
            log.info("Email alerting: attempt %s", str(attempt))
            try:
                smtp_conn = _get_smtp_connection(smtp_host, smtp_port, smtp_timeout, smtp_ssl)
            except smtplib.SMTPServerDisconnected:
                if attempt < smtp_retry_limit:
                    continue
                raise
            if smtp_starttls:
                smtp_conn.starttls()
            if smtp_user and smtp_password:
                smtp_conn.login(smtp_user, smtp_password)
            log.info("Sent an alert email to %s", e_to)
            smtp_conn.sendmail(e_from, e_to, mime_msg.as_string())
            smtp_conn.quit()
            break

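# Usage sketch (assumption): build a MIME message and hand it to
# send_mime_email; dryrun=True exercises the configuration path without
# opening an SMTP session. Addresses are placeholders.
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart()
msg["Subject"] = "Airflow alert"
msg["From"] = "airflow@example.com"
msg["To"] = "ops@example.com"
msg.attach(MIMEText("All tasks succeeded.", "plain"))

send_mime_email("airflow@example.com", ["ops@example.com"], msg, dryrun=True)
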
def get_link(
    self,
    operator: "AbstractOperator",
    dttm: Optional[datetime] = None,
    *,
    ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
    """
    Get link to qubole command result page.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    conn = BaseHook.get_connection(
        getattr(operator, "qubole_conn_id", None)
        or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
    )
    if conn and conn.host:
        host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
    else:
        host = 'https://api.qubole.com/v2/analyze?command_id='
    if ti_key:
        qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key)
    else:
        assert dttm
        qds_command_id = XCom.get_one(
            key='qbol_cmd_id', dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
        )
    url = host + str(qds_command_id) if qds_command_id else ''
    return url

def test_dbapi_get_sqlalchemy_engine(self):
    conn = BaseHook.get_connection(conn_id='test_uri')
    hook = conn.get_hook()
    engine = hook.get_sqlalchemy_engine()
    assert isinstance(engine, sqlalchemy.engine.Engine)
    assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == str(engine.url)

def get_link(
    self,
    operator,
    dttm=None,
    *,
    ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
    if ti_key:
        run_id = XCom.get_value(key="run_id", ti_key=ti_key)
    else:
        assert dttm
        run_id = XCom.get_one(
            key="run_id",
            dag_id=operator.dag.dag_id,
            task_id=operator.task_id,
            execution_date=dttm,
        )
    conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
    subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
    # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory
    # connection or passed directly to the operator.
    resource_group_name = operator.resource_group_name or conn.extra_dejson.get(
        "extra__azure_data_factory__resource_group_name"
    )
    factory_name = operator.factory_name or conn.extra_dejson.get(
        "extra__azure_data_factory__factory_name"
    )
    url = (
        f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}"
        f"?factory=/subscriptions/{subscription_id}/"
        f"resourceGroups/{resource_group_name}/providers/Microsoft.DataFactory/"
        f"factories/{factory_name}"
    )
    return url

def __init__(self, glue_crawler_name, aws_conn_id='aws_default', read_timeout=2000, *args, **kwargs):
    """
    Trigger AWS Glue crawler.

    :param glue_crawler_name: name of Glue crawler
    :param read_timeout: time to wait for a Resource response before closing the connection
    :param args:
    :param kwargs:
    """
    super(GlueCrawlerOperator, self).__init__(*args, **kwargs)
    self.glue_crawler_name = glue_crawler_name
    if read_timeout is not None:
        config = Config(read_timeout=read_timeout, retries={'max_attempts': 0})
    else:
        config = Config(retries={'max_attempts': 0})
    if aws_conn_id is not None:
        connection = BaseHook.get_connection(aws_conn_id)
        self.client = boto3.client(
            'glue',
            aws_access_key_id=connection.login,
            aws_secret_access_key=connection.password,
            config=config,
            region_name=connection.extra_dejson.get('region_name'),
        )
    else:
        raise AttributeError('Please pass a valid aws_connection_id')

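# Usage sketch (assumption): instantiating the operator inside a DAG;
# the task id, crawler name, and timeout below are hypothetical.
run_crawler = GlueCrawlerOperator(
    task_id="run_glue_crawler",
    glue_crawler_name="my_crawler",
    read_timeout=1200,
)
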
def task_fail_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    if var_loader.get_git_user() != "cloud-bulldozer":
        print("Task Failed")
        return
    if context.get('task_instance').task_id != "final_status":
        print(context.get('task_instance').task_id, "Task failed")
        return
    slack_msg = """
        :red_circle: DAG Failed {mem}
        *Task*: {task}
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        task=context.get('task_instance').task_id,
        dag=context.get('task_instance').dag_id,
        mem=alert_members(context),
        exec_date=context.get('execution_date'),
        log_url=get_hyperlink(context),
    )
    failed_alert = SlackWebhookOperator(
        task_id='slack_test',
        http_conn_id='slack',
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username='******',
        link_names=True,
    )
    return failed_alert.execute(context=context)

def create_connection():
    try:
        # Connection objects have no connect() method; build the matching
        # hook and return a live DB-API connection instead.
        db_con = BaseHook.get_connection('skyeng_db')
        return db_con.get_hook().get_conn()
    except Exception as e:
        logging.error(e)
        return None

def get_airflow_gcp_credentials(conn_id: str) -> tuple:
    connection = BaseHook.get_connection(conn_id)
    conn_extra = connection.extra_dejson
    key_access_keyfile = "extra__google_cloud_platform__keyfile_dict"
    keyfile_str = conn_extra[key_access_keyfile]
    keyfile_dict = json.loads(keyfile_str)
    project_id = keyfile_dict["project_id"]
    return (project_id, keyfile_dict)

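# Usage sketch (assumption): the returned keyfile dict can be fed to
# google-auth's standard entry point for in-memory service-account keys.
# The connection id below is hypothetical.
from google.oauth2 import service_account

project_id, keyfile_dict = get_airflow_gcp_credentials("google_cloud_default")
credentials = service_account.Credentials.from_service_account_info(keyfile_dict)
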
def kerberos_auth(conn_id: str) -> 'hdfs.ext.kerberos.KerberosClient':
    logging.info('Getting kerberos ticket ...')
    # Fetch the connection once instead of once per attribute.
    conn = BaseHook.get_connection(conn_id)
    passwd = subprocess.Popen(('echo', conn.password), stdout=subprocess.PIPE)
    subprocess.call(('kinit', conn.login), stdin=passwd.stdout)
    session = requests.Session()
    session.verify = False
    return KerberosClient(
        f'https://{conn.host}:{conn.port}',
        mutual_auth="REQUIRED",
        max_concurrency=256,
        session=session,
    )

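# Usage sketch (assumption): the returned KerberosClient follows the hdfs
# library's Client API, e.g. list() to enumerate a directory. The conn_id
# and path below are hypothetical.
client = kerberos_auth("webhdfs_default")
print(client.list("/user"))
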
def pg_params(conn_id: str = "postgres_default") -> str:
    """
    Args:
        conn_id: database connection that is provided with default parameters

    Returns:
        connection string with default params
    """
    connection_uri = BaseHook.get_connection(conn_id).get_uri().split("?")[0]
    return f"{connection_uri} -X --set ON_ERROR_STOP=1"

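# Usage sketch (assumption): the returned string is meant to be spliced into
# a psql invocation, e.g. from a BashOperator; the SQL file path below is
# hypothetical.
psql_cmd = f"psql {pg_params()} -f /tmp/load.sql"
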
def _get_hook(self):
    conn = BaseHook.get_connection(self.conn_id)
    hook = conn.get_hook(hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    return hook

def _get_hook(self) -> DbApiHook:
    self.log.debug("Get connection for %s", self.sql_conn_id)
    conn = BaseHook.get_connection(self.sql_conn_id)
    hook = conn.get_hook()
    if not callable(getattr(hook, 'get_pandas_df', None)):
        raise AirflowException(
            "This hook is not supported. The hook class must have get_pandas_df method."
        )
    return hook

def connections_get(args):
    """Get a connection."""
    try:
        conn = BaseHook.get_connection(args.conn_id)
    except AirflowNotFoundException:
        raise SystemExit("Connection not found.")
    AirflowConsole().print_as(
        data=[conn],
        output=args.output,
        mapper=_connection_mapper,
    )

def connection(self) -> Iterator[SwiftService]:
    """Set up the objectstore connection.

    Yields:
        Iterator: An objectstore connection
    """
    options = None
    if self.swift_conn_id != "swift_default":
        options = {}
        connection = BaseHook.get_connection(self.swift_conn_id)
        options["os_username"] = connection.login
        options["os_password"] = connection.password
        options["os_tenant_name"] = connection.host
    yield SwiftService(options=options)

def _download_citi_bike_data(ts_nodash, **_):
    citibike_conn = BaseHook.get_connection(conn_id="citibike")
    url = f"http://{citibike_conn.host}:{citibike_conn.port}/recent/minute/15"
    response = requests.get(url, auth=HTTPBasicAuth(citibike_conn.login, citibike_conn.password))
    data = response.json()
    s3_hook = S3Hook(aws_conn_id="s3")
    s3_hook.load_string(
        string_data=json.dumps(data),
        key=f"raw/citibike/{ts_nodash}.json",
        bucket_name="datalake",
    )

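# Usage sketch (assumption): wired up as a PythonOperator inside a DAG so
# that ts_nodash is injected from the task's templated context (Airflow 2
# passes context variables to matching callable arguments).
from airflow.operators.python import PythonOperator

download_citi_bike_data = PythonOperator(
    task_id="download_citi_bike_data",
    python_callable=_download_citi_bike_data,
)
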
def _get_hook(self) -> DbApiHook:
    self.log.debug("Get connection for %s", self.sql_conn_id)
    conn = BaseHook.get_connection(self.sql_conn_id)
    if Version(version) >= Version('2.3'):
        # "hook_params" was introduced into "get_hook()" only in Airflow 2.3.
        hook = conn.get_hook(hook_params=self.sql_hook_params)  # ignore airflow compat check
    else:
        # For supporting Airflow versions < 2.3, we backport the "get_hook()" method. This should be
        # removed when "apache-airflow-providers-slack" depends on Airflow >= 2.3.
        hook = _backported_get_hook(conn, hook_params=self.sql_hook_params)
    if not callable(getattr(hook, 'get_pandas_df', None)):
        raise AirflowException(
            "This hook is not supported. The hook class must have get_pandas_df method."
        )
    return hook

def _hook(self):
    """Get DB Hook based on connection type."""
    self.log.debug("Get connection for %s", self.conn_id)
    conn = BaseHook.get_connection(self.conn_id)
    hook = conn.get_hook(hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    if self.database:
        hook.schema = self.database
    return hook

def _get_hook(self):
    conn = BaseHook.get_connection(self.conn_id)
    if Version(version) >= Version('2.3'):
        # "hook_params" was introduced into "get_hook()" only in Airflow 2.3.
        hook = conn.get_hook(hook_params=self.hook_params)  # ignore airflow compat check
    else:
        # For supporting Airflow versions < 2.3, we backport the "get_hook()" method. This should be
        # removed when "apache-airflow-providers-common-sql" depends on Airflow >= 2.3.
        hook = _backported_get_hook(conn, hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    return hook

def _get_hook(self):
    self.log.debug("Get connection for %s", self.conn_id)
    conn = BaseHook.get_connection(self.conn_id)
    if conn.conn_type not in ALLOWED_CONN_TYPE:
        raise AirflowException(
            "The connection type is not supported by BranchSQLOperator. "
            "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
        )
    if not self._hook:
        self._hook = conn.get_hook()
        if self.database:
            self._hook.schema = self.database
    return self._hook

def execute(self, context: 'Context'):
    self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
    hook = CloudSQLDatabaseHook(
        gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
        gcp_conn_id=self.gcp_conn_id,
        default_gcp_project_id=self.gcp_connection.extra_dejson.get(
            'extra__google_cloud_platform__project'
        ),
    )
    hook.validate_ssl_certs()
    connection = hook.create_connection()
    hook.validate_socket_path_length()
    database_hook = hook.get_database_hook(connection=connection)
    try:
        self._execute_query(hook, database_hook)
    finally:
        hook.cleanup_database_hook()

def poke(self, context: dict) -> bool:
    conn = BaseHook.get_connection(self.qubole_conn_id)
    Qubole.configure(api_token=conn.password, api_url=conn.host)
    self.log.info('Poking: %s', self.data)
    status = False
    try:
        status = self.sensor_class.check(self.data)  # type: ignore[attr-defined]
    except Exception as e:
        self.log.exception(e)
        status = False
    self.log.info('Status of this Poke: %s', status)
    return status

def dag_success_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    slack_msg = """
        :large_green_circle: DAG Succeeded.
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        dag=context.get('task_instance').dag_id,
        exec_date=context.get('execution_date'),
        log_url=context.get('task_instance').log_url,
    )
    # Renamed from "failed_alert": this callback fires on success.
    success_alert = SlackWebhookOperator(
        task_id='slack_test',
        http_conn_id='slack',
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username='******',
    )
    return success_alert.execute(context=context)

def __init__(self, lambda_function_name, airflow_context_to_lambda_payload=None,
             additional_payload=None, aws_conn_id='aws_default', read_timeout=2000, *args, **kwargs):
    """
    Trigger AWS Lambda function.

    :param lambda_function_name: name of Lambda function
    :param airflow_context_to_lambda_payload: function extracting fields from Airflow context to Lambda payload
    :param additional_payload: additional parameters for Lambda payload
    :param aws_conn_id: aws connection id used to call the Lambda function
    :param read_timeout: time to wait for a Resource response before closing the connection
    :param args:
    :param kwargs:
    """
    super(ExecuteLambdaOperator, self).__init__(*args, **kwargs)
    if additional_payload is None:
        additional_payload = {}
    self.airflow_context_to_lambda_payload = airflow_context_to_lambda_payload
    self.additional_payload = additional_payload
    self.lambda_function_name = lambda_function_name
    if read_timeout is not None:
        config = Config(read_timeout=read_timeout, retries={'max_attempts': 0})
    else:
        config = Config(retries={'max_attempts': 0})
    if aws_conn_id is not None:
        connection = BaseHook.get_connection(aws_conn_id)
        self.lambda_client = boto3.client(
            'lambda',
            aws_access_key_id=connection.login,
            aws_secret_access_key=connection.password,
            config=config,
            region_name=connection.extra_dejson.get('region_name'),
        )
    else:
        raise AttributeError('Please pass a valid aws_connection_id')

def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
    """
    Get link to qubole command result page.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    ti = TaskInstance(task=operator, execution_date=dttm)
    conn = BaseHook.get_connection(
        getattr(operator, "qubole_conn_id", None)
        or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
    )
    if conn and conn.host:
        host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
    else:
        host = 'https://api.qubole.com/v2/analyze?command_id='
    qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
    url = host + str(qds_command_id) if qds_command_id else ''
    return url