def test_dbapi_get_uri(self):
    conn = BaseHook.get_connection(conn_id='test_uri')
    hook = conn.get_hook()
    assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == hook.get_uri()
    conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
    hook2 = conn2.get_hook()
    assert 'postgres://ec2.compute.com/the_database' == hook2.get_uri()
Example #2
def task_fail_slack_alert(context):
    """
    Callback that can be attached to a DAG to send a Slack alert when a task fails.

    Args:
        context (dict): Context variable passed in by Airflow

    Returns:
        None: Calls the SlackWebhookOperator execute method internally
    """
    connection = BaseHook.get_connection(SLACK_CONN_ID)
    slack_webhook_token = connection.password
    slack_msg = """*Status*: :x: Task Failed\n*Task*: {task}\n*Dag*: {dag}\n*Execution Time*: {exec_date}\n*Log Url*: {log_url}""".format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    if connection.extra_dejson.get('users'):
        slack_msg = slack_msg + '\n' + connection.extra_dejson.get('users')

    failed_alert = SlackWebhookOperator(
        task_id="slack_task",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username="******",
        link_names=True
    )

    return failed_alert.execute(context=context)
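In practice a callback like this is attached through `on_failure_callback`; a minimal sketch, assuming `SLACK_CONN_ID` is defined and a Slack HTTP connection named `slack` exists in the Airflow instance:

from datetime import datetime

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    dag_id="example_with_slack_alert",  # hypothetical DAG id
    start_date=datetime(2023, 1, 1),
    schedule_interval=None,
    default_args={"on_failure_callback": task_fail_slack_alert},
) as dag:
    BashOperator(task_id="always_fails", bash_command="exit 1")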
Example #3
def get_minio_object(pandas_read_callable,
                     bucket,
                     paths,
                     pandas_read_callable_kwargs=None):
    s3_conn = BaseHook.get_connection(conn_id="s3")
    minio_client = Minio(
        s3_conn.extra_dejson["host"].split("://")[1],
        access_key=s3_conn.extra_dejson["aws_access_key_id"],
        secret_key=s3_conn.extra_dejson["aws_secret_access_key"],
        secure=False,
    )

    if isinstance(paths, str):
        if paths.startswith("[") and paths.endswith("]"):
            # Parse a stringified list safely; eval() on arbitrary input is unsafe.
            paths = ast.literal_eval(paths)  # requires "import ast"
        else:
            paths = [paths]

    if pandas_read_callable_kwargs is None:
        pandas_read_callable_kwargs = {}

    dfs = []
    for path in paths:
        minio_object = minio_client.get_object(bucket_name=bucket,
                                               object_name=path)
        df = pandas_read_callable(minio_object, **pandas_read_callable_kwargs)
        dfs.append(df)
    return pd.concat(dfs)
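A hedged usage sketch: the bucket and object keys below are illustrative, and the reader callable plus its kwargs are forwarded unchanged to pandas:

import pandas as pd

df = get_minio_object(
    pandas_read_callable=pd.read_csv,
    bucket="datalake",  # hypothetical bucket
    paths=["raw/2023/01.csv", "raw/2023/02.csv"],  # hypothetical object keys
)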
Example #4
def slack_success_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password

    msg = """
          :green_circle: Task Successful. 
          *Task*: {task}  
          *Dag*: {dag} 
          *Execution Time*: {exec_date}  
          *Log Url*: {log_url} 
          """.format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        ti=context.get("task_instance"),
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )

    success_alert = SlackWebhookOperator(
        task_id="slack_success_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )

    return success_alert.execute(context=context)
Example #5
def slack_dag_failure_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password
    icon_color = (":red_circle:" if configuration.ENVIRONMENT.lower()
                  == "production" else ":yellow_circle:")

    msg = """
          {icon_color} Task Failed.
          *Task*: {task}
          *Dag*: {dag}
          *Execution Time*: {exec_date}
          *Log Url*: {log_url}
          """.format(
        icon_color=icon_color,
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )

    failed_alert = SlackWebhookOperator(
        task_id="slack_failed_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )

    return failed_alert.execute(context=context)
Example #6
def _post_sendgrid_mail(mail_data: Dict,
                        conn_id: str = "sendgrid_default") -> None:
    api_key = None
    try:
        conn = BaseHook.get_connection(conn_id)
        api_key = conn.password
    except AirflowException:
        pass
    if api_key is None:
        warnings.warn(
            "Fetching Sendgrid credentials from environment variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        api_key = os.environ.get('SENDGRID_API_KEY')
    sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key)
    response = sendgrid_client.client.mail.send.post(request_body=mail_data)
    # 2xx status code.
    if 200 <= response.status_code < 300:
        log.info(
            'Email with subject %s is successfully sent to recipients: %s',
            mail_data['subject'],
            mail_data['personalizations'],
        )
    else:
        log.error(
            'Failed to send out email with subject %s, status code: %s',
            mail_data['subject'],
            response.status_code,
        )
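For context, a minimal `mail_data` payload in the shape SendGrid's v3 mail-send endpoint expects (all addresses are placeholders):

mail_data = {
    "personalizations": [{"to": [{"email": "oncall@example.com"}]}],
    "from": {"email": "airflow@example.com"},
    "subject": "Airflow notification",
    "content": [{"type": "text/plain", "value": "All tasks finished."}],
}
_post_sendgrid_mail(mail_data)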
Example #7
def send_mime_email(
    e_from: str,
    e_to: Union[str, List[str]],
    mime_msg: MIMEMultipart,
    conn_id: str = "smtp_default",
    dryrun: bool = False,
) -> None:
    """Send MIME email."""
    smtp_host = conf.get('smtp', 'SMTP_HOST')
    smtp_port = conf.getint('smtp', 'SMTP_PORT')
    smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
    smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
    smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
    smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')
    smtp_user = None
    smtp_password = None

    if conn_id is not None:
        try:
            from airflow.hooks.base import BaseHook

            airflow_conn = BaseHook.get_connection(conn_id)
            smtp_user = airflow_conn.login
            smtp_password = airflow_conn.password
        except AirflowException:
            pass
    if smtp_user is None or smtp_password is None:
        warnings.warn(
            "Fetching SMTP credentials from configuration variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        try:
            smtp_user = conf.get('smtp', 'SMTP_USER')
            smtp_password = conf.get('smtp', 'SMTP_PASSWORD')
        except AirflowConfigException:
            log.debug(
                "No user/password found for SMTP, so logging in with no authentication."
            )

    if not dryrun:
        for attempt in range(1, smtp_retry_limit + 1):
            log.info("Email alerting: attempt %s", str(attempt))
            try:
                smtp_conn = _get_smtp_connection(smtp_host, smtp_port,
                                                 smtp_timeout, smtp_ssl)
            except smtplib.SMTPServerDisconnected:
                if attempt < smtp_retry_limit:
                    continue
                raise

            if smtp_starttls:
                smtp_conn.starttls()
            if smtp_user and smtp_password:
                smtp_conn.login(smtp_user, smtp_password)
            log.info("Sent an alert email to %s", e_to)
            smtp_conn.sendmail(e_from, e_to, mime_msg.as_string())
            smtp_conn.quit()
            break
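A sketch of building the `mime_msg` argument with the standard library; the addresses are placeholders and `dryrun=True` keeps the example from actually sending:

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart()
msg["Subject"] = "Airflow alert"
msg["From"] = "airflow@example.com"
msg["To"] = "oncall@example.com"
msg.attach(MIMEText("A task needs attention.", "plain"))

send_mime_email(e_from=msg["From"], e_to=[msg["To"]], mime_msg=msg, dryrun=True)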
Example #8
    def get_link(
        self,
        operator: "AbstractOperator",
        dttm: Optional[datetime] = None,
        *,
        ti_key: Optional["TaskInstanceKey"] = None,
    ) -> str:
        """
        Get link to qubole command result page.

        :param operator: operator
        :param dttm: datetime
        :return: url link
        """
        conn = BaseHook.get_connection(
            getattr(operator, "qubole_conn_id", None)
            or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
        )
        if conn and conn.host:
            host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        else:
            host = 'https://api.qubole.com/v2/analyze?command_id='
        if ti_key:
            qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key)
        else:
            assert dttm
            qds_command_id = XCom.get_one(
                key='qbol_cmd_id', dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
            )
        url = host + str(qds_command_id) if qds_command_id else ''
        return url
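Methods like get_link live on operator extra-link classes, which are registered through a plugin; a hedged sketch, assuming the enclosing class is the Qubole provider's `QDSLink` (the plugin name is hypothetical):

from airflow.plugins_manager import AirflowPlugin

class QuboleLinkPlugin(AirflowPlugin):
    name = "qubole_link_plugin"  # hypothetical plugin name
    operator_extra_links = [QDSLink()]  # the class whose get_link is shown above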
Example #9
def test_dbapi_get_sqlalchemy_engine(self):
    conn = BaseHook.get_connection(conn_id='test_uri')
    hook = conn.get_hook()
    engine = hook.get_sqlalchemy_engine()
    assert isinstance(engine, sqlalchemy.engine.Engine)
    assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == str(
        engine.url)
Example #10
    def get_link(
        self,
        operator,
        dttm=None,
        *,
        ti_key: Optional["TaskInstanceKey"] = None,
    ) -> str:
        if ti_key:
            run_id = XCom.get_value(key="run_id", ti_key=ti_key)
        else:
            assert dttm
            run_id = XCom.get_one(
                key="run_id",
                dag_id=operator.dag.dag_id,
                task_id=operator.task_id,
                execution_date=dttm,
            )

        conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
        subscription_id = conn.extra_dejson[
            "extra__azure_data_factory__subscriptionId"]
        # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory
        # connection or passed directly to the operator.
        resource_group_name = operator.resource_group_name or conn.extra_dejson.get(
            "extra__azure_data_factory__resource_group_name")
        factory_name = operator.factory_name or conn.extra_dejson.get(
            "extra__azure_data_factory__factory_name")
        url = (
            f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}"
            f"?factory=/subscriptions/{subscription_id}/"
            f"resourceGroups/{resource_group_name}/providers/Microsoft.DataFactory/"
            f"factories/{factory_name}")

        return url
Example #11
    def __init__(self,
                 glue_crawler_name,
                 aws_conn_id='aws_default',
                 read_timeout=2000,
                 *args,
                 **kwargs):
        """
        Trigger AWS Glue crawler
        :param glue_crawler_name: name of Glue crawler
        :param read_timeout: read time in order to wait Resource response before closing connection
        :param args:
        :param kwargs:
        """
        super(GlueCrawlerOperator, self).__init__(*args, **kwargs)

        self.glue_crawler_name = glue_crawler_name
        if read_timeout is not None:
            config = Config(read_timeout=read_timeout,
                            retries={'max_attempts': 0})
        else:
            config = Config(retries={'max_attempts': 0})
        if aws_conn_id is not None:
            connection = BaseHook.get_connection(aws_conn_id)
            self.client = boto3.client(
                'glue',
                aws_access_key_id=connection.login,
                aws_secret_access_key=connection.password,
                config=config,
                region_name=connection.extra_dejson.get('region_name'))
        else:
            raise AttributeError('Please pass a valid aws_conn_id')
Example #12
def task_fail_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    if var_loader.get_git_user() != "cloud-bulldozer":
        print("Task Failed")
        return
    if context.get('task_instance').task_id != "final_status":
        print(context.get('task_instance').task_id, "Task failed")
        return

    slack_msg = """
            :red_circle: DAG Failed {mem} 
            *Task*: {task}  
            *Dag*: {dag} 
            *Execution Time*: {exec_date}  
            *Log Url*: {log_url} 
            """.format(
        task=context.get('task_instance').task_id,
        dag=context.get('task_instance').dag_id,
        mem=alert_members(context),
        ti=context.get('task_instance'),
        exec_date=context.get('execution_date'),
        log_url=get_hyperlink(context),
    )
    failed_alert = SlackWebhookOperator(task_id='slack_test',
                                        http_conn_id='slack',
                                        webhook_token=slack_webhook_token,
                                        message=slack_msg,
                                        username='******',
                                        link_names=True)
    return failed_alert.execute(context=context)
Example #13
def create_connection():
    try:
        db_con = BaseHook.get_connection('skyeng_db')
        # Connection has no connect() method; open a DB-API connection
        # through the hook associated with the connection instead.
        return db_con.get_hook().get_conn()
    except Exception as e:
        logging.error(e)
    return None
Example #14
def get_airflow_gcp_credentials(conn_id: str) -> tuple:
    connection = BaseHook.get_connection(conn_id)
    conn_extra = connection.extra_dejson
    key_access_keyfile = "extra__google_cloud_platform__keyfile_dict"
    keyfile_str = conn_extra[key_access_keyfile]
    keyfile_dict = json.loads(keyfile_str)
    project_id = keyfile_dict["project_id"]
    return (project_id, keyfile_dict)
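One way the returned keyfile dict is typically consumed, assuming `google-auth` is installed (the connection id is illustrative):

from google.oauth2 import service_account

project_id, keyfile_dict = get_airflow_gcp_credentials("google_cloud_default")
credentials = service_account.Credentials.from_service_account_info(keyfile_dict)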
Example #15
def kerberos_auth(conn_id: str) -> 'hdfs.ext.kerberos.KerberosClient':
    logging.info('Getting kerberos ticket ...')
    conn = BaseHook.get_connection(conn_id)  # one lookup instead of four

    # Pipe the password into kinit's stdin to obtain a ticket.
    passwd = subprocess.Popen(('echo', conn.password), stdout=subprocess.PIPE)
    subprocess.call(('kinit', conn.login), stdin=passwd.stdout)

    session = requests.Session()
    session.verify = False

    return KerberosClient(f'https://{conn.host}:{conn.port}',
                          mutual_auth="REQUIRED",
                          max_concurrency=256,
                          session=session)
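Usage sketch, assuming a `webhdfs_default` connection exists and `kinit` is on the worker's PATH (the directory is illustrative):

client = kerberos_auth("webhdfs_default")
print(client.list("/data"))  # list a WebHDFS directory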
Example #16
def pg_params(conn_id: str = "postgres_default") -> str:
    """
    Args:
        conn_id: id of the database connection to use
    Returns:
        the connection URI followed by default psql options
    """
    connection_uri = BaseHook.get_connection(conn_id).get_uri().split("?")[0]
    return f"{connection_uri} -X --set ON_ERROR_STOP=1"
Example #17
def _get_hook(self):
    conn = BaseHook.get_connection(self.conn_id)
    hook = conn.get_hook(hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    return hook
Example #18
def _get_hook(self) -> DbApiHook:
    self.log.debug("Get connection for %s", self.sql_conn_id)
    conn = BaseHook.get_connection(self.sql_conn_id)
    hook = conn.get_hook()
    if not callable(getattr(hook, 'get_pandas_df', None)):
        raise AirflowException(
            "This hook is not supported. The hook class must have a get_pandas_df method."
        )
    return hook
Example #19
def connections_get(args):
    """Get a connection."""
    try:
        conn = BaseHook.get_connection(args.conn_id)
    except AirflowNotFoundException:
        raise SystemExit("Connection not found.")
    AirflowConsole().print_as(
        data=[conn],
        output=args.output,
        mapper=_connection_mapper,
    )
Example #20
    def connection(self) -> Iterator[SwiftService]:
        """Set up the objectstore connection.

        Yields:
            Iterator: An objectstore connection

        """
        options = None
        if self.swift_conn_id != "swift_default":
            options = {}
            connection = BaseHook.get_connection(self.swift_conn_id)
            options["os_username"] = connection.login
            options["os_password"] = connection.password
            options["os_tenant_name"] = connection.host
        yield SwiftService(options=options)
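Because this method yields, it is presumably wrapped with `contextlib.contextmanager` in the enclosing hook class; a usage sketch under that assumption, with a hypothetical container name:

with hook.connection() as swift:  # "hook" is an instance of the enclosing class
    for page in swift.list(container="my-container"):
        if page["success"]:
            print([obj["name"] for obj in page["listing"]])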
Example #21
def _download_citi_bike_data(ts_nodash, **_):
    citibike_conn = BaseHook.get_connection(conn_id="citibike")

    url = f"http://{citibike_conn.host}:{citibike_conn.port}/recent/minute/15"
    response = requests.get(url,
                            auth=HTTPBasicAuth(citibike_conn.login,
                                               citibike_conn.password))
    data = response.json()

    s3_hook = S3Hook(aws_conn_id="s3")
    s3_hook.load_string(
        string_data=json.dumps(data),
        key=f"raw/citibike/{ts_nodash}.json",
        bucket_name="datalake",
    )
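Wiring sketch: under a PythonOperator, Airflow 2 passes context variables such as `ts_nodash` to the callable as keyword arguments, which the signature above absorbs:

from airflow.operators.python import PythonOperator

download_citi_bike_data = PythonOperator(
    task_id="download_citi_bike_data",
    python_callable=_download_citi_bike_data,
)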
Example #22
def _get_hook(self) -> DbApiHook:
    self.log.debug("Get connection for %s", self.sql_conn_id)
    conn = BaseHook.get_connection(self.sql_conn_id)
    if Version(version) >= Version('2.3'):
        # "hook_params" was introduced into "get_hook()" only in Airflow 2.3.
        hook = conn.get_hook(hook_params=self.sql_hook_params)  # ignore airflow compat check
    else:
        # To support Airflow versions < 2.3, we backport the "get_hook()" method. This should be
        # removed once "apache-airflow-providers-slack" depends on Airflow >= 2.3.
        hook = _backported_get_hook(conn, hook_params=self.sql_hook_params)
    if not callable(getattr(hook, 'get_pandas_df', None)):
        raise AirflowException(
            "This hook is not supported. The hook class must have a get_pandas_df method."
        )
    return hook
Example #23
    def _hook(self):
        """Get DB Hook based on connection type"""
        self.log.debug("Get connection for %s", self.conn_id)
        conn = BaseHook.get_connection(self.conn_id)

        hook = conn.get_hook(hook_params=self.hook_params)
        if not isinstance(hook, DbApiHook):
            raise AirflowException(
                f'The connection type is not supported by {self.__class__.__name__}. '
                f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
            )

        if self.database:
            hook.schema = self.database

        return hook
Example #24
def _get_hook(self):
    conn = BaseHook.get_connection(self.conn_id)
    if Version(version) >= Version('2.3'):
        # "hook_params" was introduced into "get_hook()" only in Airflow 2.3.
        hook = conn.get_hook(hook_params=self.hook_params)  # ignore airflow compat check
    else:
        # To support Airflow versions < 2.3, we backport the "get_hook()" method. This should be
        # removed once "apache-airflow-providers-common-sql" depends on Airflow >= 2.3.
        hook = _backported_get_hook(conn, hook_params=self.hook_params)
    if not isinstance(hook, DbApiHook):
        raise AirflowException(
            f'The connection type is not supported by {self.__class__.__name__}. '
            f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
        )
    return hook
Example #25
    def _get_hook(self):
        self.log.debug("Get connection for %s", self.conn_id)
        conn = BaseHook.get_connection(self.conn_id)

        if conn.conn_type not in ALLOWED_CONN_TYPE:
            raise AirflowException(
                "The connection type is not supported by BranchSQLOperator. "
                "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE)))

        if not self._hook:
            self._hook = conn.get_hook()
            if self.database:
                self._hook.schema = self.database

        return self._hook
Example #26
def execute(self, context: 'Context'):
    self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
    hook = CloudSQLDatabaseHook(
        gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
        gcp_conn_id=self.gcp_conn_id,
        default_gcp_project_id=self.gcp_connection.extra_dejson.get(
            'extra__google_cloud_platform__project'),
    )
    hook.validate_ssl_certs()
    connection = hook.create_connection()
    hook.validate_socket_path_length()
    database_hook = hook.get_database_hook(connection=connection)
    try:
        self._execute_query(hook, database_hook)
    finally:
        hook.cleanup_database_hook()
Example #27
    def poke(self, context: dict) -> bool:
        conn = BaseHook.get_connection(self.qubole_conn_id)
        Qubole.configure(api_token=conn.password, api_url=conn.host)

        self.log.info('Poking: %s', self.data)

        status = False
        try:
            status = self.sensor_class.check(self.data)  # type: ignore[attr-defined]
        except Exception as e:
            self.log.exception(e)
            status = False

        self.log.info('Status of this Poke: %s', status)

        return status
Example #28
def dag_success_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    slack_msg = """
            :large_green_circle: DAG Succeeded.
            *Dag*: {dag}
            *Execution Time*: {exec_date}
            *Log Url*: {log_url}
            """.format(
        dag=context.get('task_instance').dag_id,
        exec_date=context.get('execution_date'),
        log_url=context.get('task_instance').log_url,
    )
    success_alert = SlackWebhookOperator(task_id='slack_test',
                                         http_conn_id='slack',
                                         webhook_token=slack_webhook_token,
                                         message=slack_msg,
                                         username='******')
    return success_alert.execute(context=context)
Example #29
    def __init__(self,
                 lambda_function_name,
                 airflow_context_to_lambda_payload=None,
                 additional_payload=None,
                 aws_conn_id='aws_default',
                 read_timeout=2000,
                 *args,
                 **kwargs):
        """
        Trigger AWS Lambda function
        :param lambda_function_name: name of Lambda function
        :param airflow_context_to_lambda_payload: function extracting fields from Airflow context to Lambda payload
        :param additional_payload: additional parameters for Lambda payload
        :param aws_conn_id: aws connection id in order to call Lambda function
        :param read_timeout: read time in order to wait Resource response before closing connection
        :param args:
        :param kwargs:
        """
        super(ExecuteLambdaOperator, self).__init__(*args, **kwargs)
        if additional_payload is None:
            additional_payload = {}
        self.airflow_context_to_lambda_payload = airflow_context_to_lambda_payload

        self.additional_payload = additional_payload
        self.lambda_function_name = lambda_function_name
        if read_timeout is not None:
            config = Config(read_timeout=read_timeout,
                            retries={'max_attempts': 0})
        else:
            config = Config(retries={'max_attempts': 0})
        if aws_conn_id is not None:
            connection = BaseHook.get_connection(aws_conn_id)
            self.lambda_client = boto3.client(
                'lambda',
                aws_access_key_id=connection.login,
                aws_secret_access_key=connection.password,
                config=config,
                region_name=connection.extra_dejson.get('region_name'))
        else:
            raise AttributeError('Please pass a valid aws_conn_id')
Example #30
    def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
        """
        Get link to qubole command result page.

        :param operator: operator
        :param dttm: datetime
        :return: url link
        """
        ti = TaskInstance(task=operator, execution_date=dttm)
        conn = BaseHook.get_connection(
            getattr(operator, "qubole_conn_id", None)
            or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
        )
        if conn and conn.host:
            host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        else:
            host = 'https://api.qubole.com/v2/analyze?command_id='
        qds_command_id = ti.xcom_pull(task_ids=operator.task_id, key='qbol_cmd_id')
        url = host + str(qds_command_id) if qds_command_id else ''
        return url