def test_dbapi_get_uri(self):
    conn = BaseHook.get_connection(conn_id='test_uri')
    hook = conn.get_hook()
    assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == hook.get_uri()

    conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
    hook2 = conn2.get_hook()
    assert 'postgres://ec2.compute.com/the_database' == hook2.get_uri()

def execute(self, context: Context):
    source_hook = BaseHook.get_hook(self.source_conn_id)
    destination_hook = BaseHook.get_hook(self.destination_conn_id)

    self.log.info("Extracting data from %s", self.source_conn_id)
    self.log.info("Executing: \n %s", self.sql)
    get_records = getattr(source_hook, 'get_records', None)
    if not callable(get_records):
        raise RuntimeError(
            f"Hook for connection {self.source_conn_id!r} "
            f"({type(source_hook).__name__}) has no `get_records` method"
        )
    results = get_records(self.sql)

    if self.preoperator:
        run = getattr(destination_hook, 'run', None)
        if not callable(run):
            raise RuntimeError(
                f"Hook for connection {self.destination_conn_id!r} "
                f"({type(destination_hook).__name__}) has no `run` method"
            )
        self.log.info("Running preoperator")
        self.log.info(self.preoperator)
        run(self.preoperator)

    insert_rows = getattr(destination_hook, 'insert_rows', None)
    if not callable(insert_rows):
        raise RuntimeError(
            f"Hook for connection {self.destination_conn_id!r} "
            f"({type(destination_hook).__name__}) has no `insert_rows` method"
        )
    self.log.info("Inserting rows into %s", self.destination_conn_id)
    insert_rows(table=self.destination_table, rows=results, **self.insert_args)

def task_fail_slack_alert(context):
    """
    Callback that can be used in a DAG to alert on task failure.

    Args:
        context (dict): Context variable passed in from Airflow

    Returns:
        None: Calls the SlackWebhookOperator execute method internally
    """
    connection = BaseHook.get_connection(SLACK_CONN_ID)
    slack_webhook_token = connection.password
    slack_msg = (
        "*Status*: :x: Task Failed\n"
        "*Task*: {task}\n"
        "*Dag*: {dag}\n"
        "*Execution Time*: {exec_date}\n"
        "*Log Url*: {log_url}"
    ).format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    if connection.extra_dejson.get('users'):
        slack_msg = slack_msg + '\n' + connection.extra_dejson.get('users')
    failed_alert = SlackWebhookOperator(
        task_id="slack_task",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username="******",
        link_names=True,
    )
    return failed_alert.execute(context=context)

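# A minimal wiring sketch (not from the source): attaching the callback above to a
# DAG so that any failed task posts to Slack. The dag_id and schedule below are
# hypothetical placeholders.
from datetime import datetime

from airflow import DAG

with DAG(
    dag_id="example_with_slack_alert",
    start_date=datetime(2022, 1, 1),
    schedule_interval="@daily",
    # default_args propagate to every task in the DAG.
    default_args={"on_failure_callback": task_fail_slack_alert},
) as dag:
    ...
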
def get_minio_object(pandas_read_callable, bucket, paths, pandas_read_callable_kwargs=None):
    s3_conn = BaseHook.get_connection(conn_id="s3")
    minio_client = Minio(
        s3_conn.extra_dejson["host"].split("://")[1],
        access_key=s3_conn.extra_dejson["aws_access_key_id"],
        secret_key=s3_conn.extra_dejson["aws_secret_access_key"],
        secure=False,
    )
    if isinstance(paths, str):
        if paths.startswith("[") and paths.endswith("]"):
            # Parse a stringified list safely with ast.literal_eval instead of eval().
            paths = ast.literal_eval(paths)
        else:
            paths = [paths]
    if pandas_read_callable_kwargs is None:
        pandas_read_callable_kwargs = {}

    dfs = []
    for path in paths:
        minio_object = minio_client.get_object(bucket_name=bucket, object_name=path)
        df = pandas_read_callable(minio_object, **pandas_read_callable_kwargs)
        dfs.append(df)
    return pd.concat(dfs)

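# Hypothetical usage sketch: MinIO's get_object() returns a file-like response,
# so any pandas reader that accepts a file handle works. Bucket and object names
# below are placeholders.
import pandas as pd

df = get_minio_object(
    pandas_read_callable=pd.read_csv,
    bucket="datalake",
    paths="raw/events.csv",
)
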
def send_mime_email(
    e_from: str,
    e_to: Union[str, List[str]],
    mime_msg: MIMEMultipart,
    conn_id: str = "smtp_default",
    dryrun: bool = False,
) -> None:
    """Send MIME email."""
    smtp_host = conf.get('smtp', 'SMTP_HOST')
    smtp_port = conf.getint('smtp', 'SMTP_PORT')
    smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
    smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
    smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
    smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')

    smtp_user = None
    smtp_password = None
    if conn_id is not None:
        try:
            from airflow.hooks.base import BaseHook

            airflow_conn = BaseHook.get_connection(conn_id)
            smtp_user = airflow_conn.login
            smtp_password = airflow_conn.password
        except AirflowException:
            pass
    if smtp_user is None or smtp_password is None:
        warnings.warn(
            "Fetching SMTP credentials from configuration variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        try:
            smtp_user = conf.get('smtp', 'SMTP_USER')
            smtp_password = conf.get('smtp', 'SMTP_PASSWORD')
        except AirflowConfigException:
            log.debug("No user/password found for SMTP, so logging in with no authentication.")

    if not dryrun:
        for attempt in range(1, smtp_retry_limit + 1):
            log.info("Email alerting: attempt %s", attempt)
            try:
                smtp_conn = _get_smtp_connection(smtp_host, smtp_port, smtp_timeout, smtp_ssl)
            except smtplib.SMTPServerDisconnected:
                if attempt < smtp_retry_limit:
                    continue
                raise

            if smtp_starttls:
                smtp_conn.starttls()
            if smtp_user and smtp_password:
                smtp_conn.login(smtp_user, smtp_password)
            smtp_conn.sendmail(e_from, e_to, mime_msg.as_string())
            # Log success only after sendmail() has actually returned.
            log.info("Sent an alert email to %s", e_to)
            smtp_conn.quit()
            break

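# Hypothetical usage sketch: building a minimal MIME message for the function
# above. dryrun=True skips the SMTP round trip; the addresses are placeholders.
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart()
msg["Subject"] = "Airflow alert"
msg["From"] = "airflow@example.com"
msg["To"] = "oncall@example.com"
msg.attach(MIMEText("A task needs attention.", "plain"))
send_mime_email(e_from=msg["From"], e_to=[msg["To"]], mime_msg=msg, dryrun=True)
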
def get_link(
    self,
    operator,
    dttm=None,
    *,
    ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
    if ti_key:
        run_id = XCom.get_value(key="run_id", ti_key=ti_key)
    else:
        assert dttm
        run_id = XCom.get_one(
            key="run_id",
            dag_id=operator.dag.dag_id,
            task_id=operator.task_id,
            execution_date=dttm,
        )

    conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
    subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
    # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory
    # connection or passed directly to the operator.
    resource_group_name = operator.resource_group_name or conn.extra_dejson.get(
        "extra__azure_data_factory__resource_group_name"
    )
    factory_name = operator.factory_name or conn.extra_dejson.get(
        "extra__azure_data_factory__factory_name"
    )
    url = (
        f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}"
        f"?factory=/subscriptions/{subscription_id}/"
        f"resourceGroups/{resource_group_name}/providers/Microsoft.DataFactory/"
        f"factories/{factory_name}"
    )
    return url

def slack_dag_failure_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password
    # The variable holds a Slack emoji, not a color, so name it accordingly.
    icon = (
        ":red_circle:"
        if configuration.ENVIRONMENT.lower() == "production"
        else ":yellow_circle:"
    )
    msg = """
        {icon} Task Failed.
        *Task*: {task}
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        icon=icon,
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    failed_alert = SlackWebhookOperator(
        task_id="slack_failed_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )
    return failed_alert.execute(context=context)

def mean_fare_per_class(titanic_df_json_str: dict):
    """
    # Mean_fare_per_class task
    Takes the JSON string of the data from XCom, converts it to a pandas DataFrame,
    and applies a few aggregations.
    The DataFrame is then converted to a list of tuples and sent to an external local DB.
    """
    # Convert to a pandas DataFrame and reshape it with groupby/aggregation:
    titanic_df = pd.read_json(titanic_df_json_str['titanic_df_json_str'])
    df = titanic_df \
        .groupby(['Pclass']) \
        .agg({'Fare': 'mean'}) \
        .reset_index()
    # Create a custom hook; the connection was created beforehand in the UI:
    pg_hook = BaseHook.get_hook('postgres_default')
    # The target table name in the local DB is pre-set under Variables in the UI:
    pg_table_name = Variable.get('mean_fares_table_name')
    # Grind the DataFrame into a list of tuples, casting to standard types (int and float):
    pg_rows = list(df.to_records(index=False))
    pg_rows_conv = [(int(t[0]), float(t[1])) for t in pg_rows]
    # Extract the DataFrame's column names:
    pg_columns = list(df.columns)
    # Send the data to the local DB:
    pg_hook.insert_rows(
        table=pg_table_name,
        rows=pg_rows_conv,
        target_fields=pg_columns,
        commit_every=0,
        replace=False,
    )

def get_link(
    self,
    operator: "AbstractOperator",
    dttm: Optional[datetime] = None,
    *,
    ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
    """
    Get the link to the Qubole command result page.

    :param operator: operator
    :param dttm: datetime
    :return: url link
    """
    conn = BaseHook.get_connection(
        getattr(operator, "qubole_conn_id", None)
        or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
    )
    if conn and conn.host:
        host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
    else:
        host = 'https://api.qubole.com/v2/analyze?command_id='

    if ti_key:
        qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key)
    else:
        assert dttm
        qds_command_id = XCom.get_one(
            key='qbol_cmd_id',
            dag_id=operator.dag_id,
            task_id=operator.task_id,
            execution_date=dttm,
        )
    url = (host + str(qds_command_id)) if qds_command_id else ''
    return url

def test_dbapi_get_sqlalchemy_engine(self):
    conn = BaseHook.get_connection(conn_id='test_uri')
    hook = conn.get_hook()
    engine = hook.get_sqlalchemy_engine()
    assert isinstance(engine, sqlalchemy.engine.Engine)
    assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == str(engine.url)

def execute(self, context):
    source_hook = BaseHook.get_hook(self.source_conn_id)

    self.log.info("Extracting data from %s", self.source_conn_id)
    self.log.info("Executing: \n %s", self.sql)
    results = source_hook.get_records(self.sql)

    destination_hook = BaseHook.get_hook(self.destination_conn_id)
    if self.preoperator:
        self.log.info("Running preoperator")
        self.log.info(self.preoperator)
        destination_hook.run(self.preoperator)

    self.log.info("Inserting rows into %s", self.destination_conn_id)
    destination_hook.insert_rows(table=self.destination_table, rows=results)

def __init__(self, glue_crawler_name, aws_conn_id='aws_default', read_timeout=2000, *args, **kwargs):
    """
    Trigger an AWS Glue crawler.

    :param glue_crawler_name: name of the Glue crawler
    :param read_timeout: how long to wait for a resource response before closing the connection
    :param args:
    :param kwargs:
    """
    super(GlueCrawlerOperator, self).__init__(*args, **kwargs)
    self.glue_crawler_name = glue_crawler_name
    if read_timeout is not None:
        config = Config(read_timeout=read_timeout, retries={'max_attempts': 0})
    else:
        config = Config(retries={'max_attempts': 0})
    if aws_conn_id is not None:
        connection = BaseHook.get_connection(aws_conn_id)
        self.client = boto3.client(
            'glue',
            aws_access_key_id=connection.login,
            aws_secret_access_key=connection.password,
            config=config,
            region_name=connection.extra_dejson.get('region_name'),
        )
    else:
        raise AttributeError('Please pass a valid aws_conn_id')

def test_check_operators(self):
    conn_id = "sqlite_default"
    captain_hook = BaseHook.get_hook(conn_id=conn_id)  # quite funny :D
    captain_hook.run("CREATE TABLE operator_test_table (a, b)")
    captain_hook.run("INSERT INTO operator_test_table VALUES (1, 2)")

    self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE)

    op = CheckOperator(
        task_id='check',
        sql="SELECT count(*) FROM operator_test_table",
        conn_id=conn_id,
        dag=self.dag,
    )
    op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    op = ValueCheckOperator(
        task_id='value_check',
        pass_value=95,
        tolerance=0.1,
        conn_id=conn_id,
        sql="SELECT 100",
        dag=self.dag,
    )
    op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    captain_hook.run("DROP TABLE operator_test_table")

def slack_success_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password
    msg = """
        :green_circle: Task Successful.
        *Task*: {task}
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    success_alert = SlackWebhookOperator(
        task_id="slack_success_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )
    return success_alert.execute(context=context)

def create_connection():
    try:
        db_conn = BaseHook.get_connection('skyeng_db')
        # Connection objects have no connect() method; open a live DB connection
        # through the matching hook instead.
        return db_conn.get_hook().get_conn()
    except Exception as e:
        logging.error(e)
        return None

def _post_sendgrid_mail(mail_data: Dict, conn_id: str = "sendgrid_default") -> None:
    api_key = None
    try:
        conn = BaseHook.get_connection(conn_id)
        api_key = conn.password
    except AirflowException:
        pass
    if api_key is None:
        warnings.warn(
            "Fetching Sendgrid credentials from environment variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        api_key = os.environ.get('SENDGRID_API_KEY')
    sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key)
    response = sendgrid_client.client.mail.send.post(request_body=mail_data)
    # 2xx status code.
    if 200 <= response.status_code < 300:
        log.info(
            'Email with subject %s is successfully sent to recipients: %s',
            mail_data['subject'],
            mail_data['personalizations'],
        )
    else:
        log.error(
            'Failed to send out email with subject %s, status code: %s',
            mail_data['subject'],
            response.status_code,
        )

def task_fail_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    if var_loader.get_git_user() != "cloud-bulldozer":
        print("Task Failed")
        return
    if context.get('task_instance').task_id != "final_status":
        print(context.get('task_instance').task_id, "Task failed")
        return
    slack_msg = """
        :red_circle: DAG Failed {mem}
        *Task*: {task}
        *Dag*: {dag}
        *Execution Time*: {exec_date}
        *Log Url*: {log_url}
        """.format(
        task=context.get('task_instance').task_id,
        dag=context.get('task_instance').dag_id,
        mem=alert_members(context),
        exec_date=context.get('execution_date'),
        log_url=get_hyperlink(context),
    )
    failed_alert = SlackWebhookOperator(
        task_id='slack_test',
        http_conn_id='slack',
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username='******',
        link_names=True,
    )
    return failed_alert.execute(context=context)

def get_db_hook(self):
    """
    Get the database hook for the connection.

    :return: the database hook object.
    :rtype: DbApiHook
    """
    return BaseHook.get_hook(conn_id=self.conn_id)

def kerberos_auth(conn_id: str) -> 'hdfs.ext.kerberos.KerberosClient':
    logging.info('Getting kerberos ticket ...')
    # Fetch the connection once instead of four separate lookups.
    conn = BaseHook.get_connection(conn_id)
    passwd = subprocess.Popen(('echo', conn.password), stdout=subprocess.PIPE)
    subprocess.call(('kinit', conn.login), stdin=passwd.stdout)
    session = requests.Session()
    session.verify = False
    return KerberosClient(
        f'https://{conn.host}:{conn.port}',
        mutual_auth="REQUIRED",
        max_concurrency=256,
        session=session,
    )

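# Hypothetical usage sketch: the returned hdfs KerberosClient behaves like any
# hdfs client. The connection id and HDFS path are placeholders.
client = kerberos_auth("webhdfs_default")
print(client.list("/user"))
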
def get_airflow_gcp_credentials(conn_id: str) -> tuple:
    connection = BaseHook.get_connection(conn_id)
    conn_extra = connection.extra_dejson
    key_access_keyfile = "extra__google_cloud_platform__keyfile_dict"
    keyfile_str = conn_extra[key_access_keyfile]
    keyfile_dict = json.loads(keyfile_str)
    project_id = keyfile_dict["project_id"]
    return (project_id, keyfile_dict)

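# Hypothetical follow-up sketch: turning the returned keyfile dict into
# google-auth credentials. The connection id is a placeholder.
from google.oauth2 import service_account

project_id, keyfile_dict = get_airflow_gcp_credentials("google_cloud_default")
credentials = service_account.Credentials.from_service_account_info(keyfile_dict)
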
def _get_hook(self):
    conn = BaseHook.get_connection(self.conn_id)
    hook = conn.get_hook(hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    return hook

def pg_params(conn_id: str = "postgres_default") -> str:
    """
    Args:
        conn_id: database connection that is provided with default parameters

    Returns:
        Connection string with default params.
    """
    connection_uri = BaseHook.get_connection(conn_id).get_uri().split("?")[0]
    return f"{connection_uri} -X --set ON_ERROR_STOP=1"

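# Hypothetical usage sketch: the returned string is shaped as psql arguments,
# so it can be spliced into a BashOperator command. The SQL file path is a
# placeholder.
from airflow.operators.bash import BashOperator

load_data = BashOperator(
    task_id="load_data",
    bash_command=f"psql {pg_params()} < /tmp/load_data.sql",
)
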
def _get_hook(self) -> DbApiHook:
    self.log.debug("Get connection for %s", self.sql_conn_id)
    conn = BaseHook.get_connection(self.sql_conn_id)
    hook = conn.get_hook()
    if not callable(getattr(hook, 'get_pandas_df', None)):
        raise AirflowException(
            "This hook is not supported. The hook class must have a get_pandas_df method."
        )
    return hook

def connections_get(args):
    """Get a connection."""
    try:
        conn = BaseHook.get_connection(args.conn_id)
    except AirflowNotFoundException:
        raise SystemExit("Connection not found.")
    AirflowConsole().print_as(
        data=[conn],
        output=args.output,
        mapper=_connection_mapper,
    )

def _get_hook(self) -> DbApiHook:
    self.log.debug("Get connection for %s", self.sql_conn_id)
    conn = BaseHook.get_connection(self.sql_conn_id)
    if Version(version) >= Version('2.3'):
        # "hook_params" was introduced into "get_hook()" only in Airflow 2.3.
        hook = conn.get_hook(hook_params=self.sql_hook_params)  # ignore airflow compat check
    else:
        # To support Airflow versions < 2.3 we backport the "get_hook()" method. This should be
        # removed when "apache-airflow-providers-slack" depends on Airflow >= 2.3.
        hook = _backported_get_hook(conn, hook_params=self.sql_hook_params)
    if not callable(getattr(hook, 'get_pandas_df', None)):
        raise AirflowException(
            "This hook is not supported. The hook class must have a get_pandas_df method."
        )
    return hook

def connection(self) -> Iterator[SwiftService]:
    """Set up the objectstore connection.

    Yields:
        Iterator: An objectstore connection
    """
    options = None
    if self.swift_conn_id != "swift_default":
        options = {}
        connection = BaseHook.get_connection(self.swift_conn_id)
        options["os_username"] = connection.login
        options["os_password"] = connection.password
        options["os_tenant_name"] = connection.host
    yield SwiftService(options=options)

def _download_citi_bike_data(ts_nodash, **_):
    citibike_conn = BaseHook.get_connection(conn_id="citibike")

    url = f"http://{citibike_conn.host}:{citibike_conn.port}/recent/minute/15"
    response = requests.get(url, auth=HTTPBasicAuth(citibike_conn.login, citibike_conn.password))
    data = response.json()

    s3_hook = S3Hook(aws_conn_id="s3")
    s3_hook.load_string(
        string_data=json.dumps(data),
        key=f"raw/citibike/{ts_nodash}.json",
        bucket_name="datalake",
    )

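# Hypothetical wiring sketch (declared inside a DAG context): the callable
# above expects ts_nodash, which Airflow 2 injects from the task context when
# the parameter name matches a context key.
from airflow.operators.python import PythonOperator

download_citi_bike_data = PythonOperator(
    task_id="download_citi_bike_data",
    python_callable=_download_citi_bike_data,
)
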
def _hook(self):
    """Get DB Hook based on connection type."""
    self.log.debug("Get connection for %s", self.conn_id)
    conn = BaseHook.get_connection(self.conn_id)
    hook = conn.get_hook(hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    if self.database:
        hook.schema = self.database
    return hook

def _get_hook(self):
    self.log.debug("Get connection for %s", self.conn_id)
    conn = BaseHook.get_connection(self.conn_id)
    if conn.conn_type not in ALLOWED_CONN_TYPE:
        raise AirflowException(
            "The connection type is not supported by BranchSQLOperator. "
            "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE))
        )
    if not self._hook:
        self._hook = conn.get_hook()
        if self.database:
            self._hook.schema = self.database
    return self._hook

def execute(self, context: 'Context'):
    self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
    hook = CloudSQLDatabaseHook(
        gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
        gcp_conn_id=self.gcp_conn_id,
        default_gcp_project_id=self.gcp_connection.extra_dejson.get('extra__google_cloud_platform__project'),
    )
    hook.validate_ssl_certs()
    connection = hook.create_connection()
    hook.validate_socket_path_length()
    database_hook = hook.get_database_hook(connection=connection)
    try:
        self._execute_query(hook, database_hook)
    finally:
        hook.cleanup_database_hook()