Example #1
    def test_dbapi_get_uri(self):
        conn = BaseHook.get_connection(conn_id='test_uri')
        hook = conn.get_hook()
        assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == hook.get_uri()
        conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
        hook2 = conn2.get_hook()
        assert 'postgres://ec2.compute.com/the_database' == hook2.get_uri()
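For this test to resolve 'test_uri', the connections must be supplied out of band. A minimal sketch, assuming the standard AIRFLOW_CONN_<CONN_ID> environment-variable convention (the unmasked credentials are placeholders):

import os

# Hypothetical setup: Airflow resolves connections from environment variables
# named AIRFLOW_CONN_<CONN_ID> (upper-cased), holding a connection URI.
os.environ["AIRFLOW_CONN_TEST_URI"] = (
    "postgres://username:password@ec2.compute.com:5432/the_database"
)
os.environ["AIRFLOW_CONN_TEST_URI_NO_CREDS"] = "postgres://ec2.compute.com/the_database"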
Example #2
    def execute(self, context: Context):
        source_hook = BaseHook.get_hook(self.source_conn_id)
        destination_hook = BaseHook.get_hook(self.destination_conn_id)

        self.log.info("Extracting data from %s", self.source_conn_id)
        self.log.info("Executing: \n %s", self.sql)
        get_records = getattr(source_hook, 'get_records', None)
        if not callable(get_records):
            raise RuntimeError(
                f"Hook for connection {self.source_conn_id!r} "
                f"({type(source_hook).__name__}) has no `get_records` method")
        results = get_records(self.sql)

        if self.preoperator:
            run = getattr(destination_hook, 'run', None)
            if not callable(run):
                raise RuntimeError(
                    f"Hook for connection {self.destination_conn_id!r} "
                    f"({type(destination_hook).__name__}) has no `run` method")
            self.log.info("Running preoperator")
            self.log.info(self.preoperator)
            run(self.preoperator)

        insert_rows = getattr(destination_hook, 'insert_rows', None)
        if not callable(insert_rows):
            raise RuntimeError(
                f"Hook for connection {self.destination_conn_id!r} "
                f"({type(destination_hook).__name__}) has no `insert_rows` method"
            )
        self.log.info("Inserting rows into %s", self.destination_conn_id)
        insert_rows(table=self.destination_table,
                    rows=results,
                    **self.insert_args)
Example #3
def task_fail_slack_alert(context):
    """
    Callback task that can be used in DAG to alert of failure task completion
    Args:
        context (dict): Context variable passed in from Airflow
    Returns:
        None: Calls the SlackWebhookOperator execute method internally
    """
    conection = BaseHook.get_connection(SLACK_CONN_ID)
    slack_webhook_token = conection.password
    slack_msg = """*Status*: :x: Task Failed\n*Task*: {task}\n*Dag*: {dag}\n*Execution Time*: {exec_date}\n*Log Url*: {log_url}""".format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        ti=context.get("task_instance"),
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )
    if conection.extra_dejson.get('users'):
        slack_msg = slack_msg + '\n' + conection.extra_dejson.get('users')

    failed_alert = SlackWebhookOperator(
        task_id="slack_task",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=slack_msg,
        username="******",
        link_names=True
    )

    return failed_alert.execute(context=context)
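A sketch of how such a callback is typically wired into a DAG; the dag_id and schedule are illustrative:

from datetime import datetime

from airflow import DAG

with DAG(
    dag_id="example_dag",  # hypothetical DAG name
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    # Fire the Slack alert whenever a task in this DAG fails.
    default_args={"on_failure_callback": task_fail_slack_alert},
) as dag:
    ...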
Example #4
def get_minio_object(pandas_read_callable,
                     bucket,
                     paths,
                     pandas_read_callable_kwargs=None):
    s3_conn = BaseHook.get_connection(conn_id="s3")
    minio_client = Minio(
        s3_conn.extra_dejson["host"].split("://")[1],
        access_key=s3_conn.extra_dejson["aws_access_key_id"],
        secret_key=s3_conn.extra_dejson["aws_secret_access_key"],
        secure=False,
    )

    if isinstance(paths, str):
        if paths.startswith("[") and paths.endswith("]"):
            # ast.literal_eval is a safer alternative to eval() for parsing a
            # stringified list (requires `import ast`).
            paths = ast.literal_eval(paths)
        else:
            paths = [paths]

    if pandas_read_callable_kwargs is None:
        pandas_read_callable_kwargs = {}

    dfs = []
    for path in paths:
        minio_object = minio_client.get_object(bucket_name=bucket,
                                               object_name=path)
        df = pandas_read_callable(minio_object, **pandas_read_callable_kwargs)
        dfs.append(df)
    return pd.concat(dfs)
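An illustrative call, assuming a CSV object in a hypothetical bucket; the MinIO response object is file-like, so pandas readers can consume it directly:

df = get_minio_object(
    pandas_read_callable=pd.read_csv,
    bucket="datalake",                  # hypothetical bucket name
    paths="raw/events/2021-01-01.csv",  # hypothetical object path
)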
Example #5
def send_mime_email(
    e_from: str,
    e_to: Union[str, List[str]],
    mime_msg: MIMEMultipart,
    conn_id: str = "smtp_default",
    dryrun: bool = False,
) -> None:
    """Send MIME email."""
    smtp_host = conf.get('smtp', 'SMTP_HOST')
    smtp_port = conf.getint('smtp', 'SMTP_PORT')
    smtp_starttls = conf.getboolean('smtp', 'SMTP_STARTTLS')
    smtp_ssl = conf.getboolean('smtp', 'SMTP_SSL')
    smtp_retry_limit = conf.getint('smtp', 'SMTP_RETRY_LIMIT')
    smtp_timeout = conf.getint('smtp', 'SMTP_TIMEOUT')
    smtp_user = None
    smtp_password = None

    if conn_id is not None:
        try:
            from airflow.hooks.base import BaseHook

            airflow_conn = BaseHook.get_connection(conn_id)
            smtp_user = airflow_conn.login
            smtp_password = airflow_conn.password
        except AirflowException:
            pass
    if smtp_user is None or smtp_password is None:
        warnings.warn(
            "Fetching SMTP credentials from configuration variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        try:
            smtp_user = conf.get('smtp', 'SMTP_USER')
            smtp_password = conf.get('smtp', 'SMTP_PASSWORD')
        except AirflowConfigException:
            log.debug(
                "No user/password found for SMTP, so logging in with no authentication."
            )

    if not dryrun:
        for attempt in range(1, smtp_retry_limit + 1):
            log.info("Email alerting: attempt %s", str(attempt))
            try:
                smtp_conn = _get_smtp_connection(smtp_host, smtp_port,
                                                 smtp_timeout, smtp_ssl)
            except smtplib.SMTPServerDisconnected:
                if attempt < smtp_retry_limit:
                    continue
                raise

            if smtp_starttls:
                smtp_conn.starttls()
            if smtp_user and smtp_password:
                smtp_conn.login(smtp_user, smtp_password)
            smtp_conn.sendmail(e_from, e_to, mime_msg.as_string())
            smtp_conn.quit()
            log.info("Sent an alert email to %s", e_to)
            break
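A minimal usage sketch with dryrun=True, so no SMTP connection is actually opened; the addresses are placeholders:

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart()
msg["Subject"] = "Airflow alert"
msg.attach(MIMEText("Job finished."))

# dryrun=True skips the retry/send loop entirely.
send_mime_email("noreply@example.com", ["oncall@example.com"], msg, dryrun=True)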
Example #6
    def get_link(
        self,
        operator,
        dttm=None,
        *,
        ti_key: Optional["TaskInstanceKey"] = None,
    ) -> str:
        if ti_key:
            run_id = XCom.get_value(key="run_id", ti_key=ti_key)
        else:
            assert dttm
            run_id = XCom.get_one(
                key="run_id",
                dag_id=operator.dag.dag_id,
                task_id=operator.task_id,
                execution_date=dttm,
            )

        conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
        subscription_id = conn.extra_dejson[
            "extra__azure_data_factory__subscriptionId"]
        # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory
        # connection or passed directly to the operator.
        resource_group_name = operator.resource_group_name or conn.extra_dejson.get(
            "extra__azure_data_factory__resource_group_name")
        factory_name = operator.factory_name or conn.extra_dejson.get(
            "extra__azure_data_factory__factory_name")
        url = (
            f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}"
            f"?factory=/subscriptions/{subscription_id}/"
            f"resourceGroups/{resource_group_name}/providers/Microsoft.DataFactory/"
            f"factories/{factory_name}")

        return url
Example #7
def slack_dag_failure_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password
    icon_color = (":red_circle" if configuration.ENVIRONMENT.lower()
                  == "production" else ":yellow_circle")

    msg = """
          {icon_color}: Task Failed. 
          *Task*: {task}  
          *Dag*: {dag}
          *Execution Time*: {exec_date}  
          *Log Url*: {log_url} 
          """.format(
        icon_color=icon_color,
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        ti=context.get("task_instance"),
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )

    failed_alert = SlackWebhookOperator(
        task_id="slack_failed_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )

    return failed_alert.execute(context=context)
Example #8
    def mean_fare_per_class(titanic_df_json_str: dict):
        """
        # Mean_fare_per_class task
        Takes the JSON string of data from XCom, converts it to a pandas
        dataframe, and applies some aggregations. The dataframe is then
        converted to a list of tuples and sent to an external local DB.
        """
        # Convert to a pandas dataframe and reshape it with groupby/aggregation:
        titanic_df = pd.read_json(titanic_df_json_str['titanic_df_json_str'])
        df = titanic_df \
            .groupby(['Pclass']) \
            .agg({'Fare': 'mean'}) \
            .reset_index()

        # Create a custom hook; the connection was created beforehand in the UI:
        pg_hook = BaseHook.get_hook('postgres_default')

        # The target table name in the local DB was set beforehand in the UI under Variables:
        pg_table_name = Variable.get('mean_fares_table_name')

        # Grind the dataframe into a list of tuples, casting values to standard types (int and float):
        pg_rows = list(df.to_records(index=False))
        pg_rows_conv = [(int(t[0]), float(t[1])) for t in pg_rows]

        # Extract the dataframe's column names:
        pg_columns = list(df.columns)

        # Send the data to the local DB:
        pg_hook.insert_rows(table=pg_table_name,
                            rows=pg_rows_conv,
                            target_fields=pg_columns,
                            commit_every=0,
                            replace=False)
Example #9
    def get_link(
        self,
        operator: "AbstractOperator",
        dttm: Optional[datetime] = None,
        *,
        ti_key: Optional["TaskInstanceKey"] = None,
    ) -> str:
        """
        Get link to qubole command result page.

        :param operator: operator
        :param dttm: datetime
        :return: url link
        """
        conn = BaseHook.get_connection(
            getattr(operator, "qubole_conn_id", None)
            or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
        )
        if conn and conn.host:
            host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        else:
            host = 'https://api.qubole.com/v2/analyze?command_id='
        if ti_key:
            qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key)
        else:
            assert dttm
            qds_command_id = XCom.get_one(
                key='qbol_cmd_id', dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
            )
        url = host + str(qds_command_id) if qds_command_id else ''
        return url
Example #10
    def test_dbapi_get_sqlalchemy_engine(self):
        conn = BaseHook.get_connection(conn_id='test_uri')
        hook = conn.get_hook()
        engine = hook.get_sqlalchemy_engine()
        assert isinstance(engine, sqlalchemy.engine.Engine)
        assert 'postgres://*****:*****@ec2.compute.com:5432/the_database' == str(
            engine.url)
Example #11
    def execute(self, context):
        source_hook = BaseHook.get_hook(self.source_conn_id)

        self.log.info("Extracting data from %s", self.source_conn_id)
        self.log.info("Executing: \n %s", self.sql)
        results = source_hook.get_records(self.sql)

        destination_hook = BaseHook.get_hook(self.destination_conn_id)
        if self.preoperator:
            self.log.info("Running preoperator")
            self.log.info(self.preoperator)
            destination_hook.run(self.preoperator)

        self.log.info("Inserting rows into %s", self.destination_conn_id)
        destination_hook.insert_rows(table=self.destination_table,
                                     rows=results)
Example #12
    def __init__(self,
                 glue_crawler_name,
                 aws_conn_id='aws_default',
                 read_timeout=2000,
                 *args,
                 **kwargs):
        """
        Trigger an AWS Glue crawler.

        :param glue_crawler_name: name of the Glue crawler
        :param aws_conn_id: Airflow connection holding the AWS credentials
        :param read_timeout: how long to wait for a resource response before closing the connection
        :param args:
        :param kwargs:
        """
        super().__init__(*args, **kwargs)

        self.glue_crawler_name = glue_crawler_name
        if read_timeout is not None:
            config = Config(read_timeout=read_timeout,
                            retries={'max_attempts': 0})
        else:
            config = Config(retries={'max_attempts': 0})
        if aws_conn_id is not None:
            connection = BaseHook.get_connection(aws_conn_id)
            self.client = boto3.client(
                'glue',
                aws_access_key_id=connection.login,
                aws_secret_access_key=connection.password,
                config=config,
                region_name=connection.extra_dejson.get('region_name'))
        else:
            raise AttributeError('Please pass a valid aws_conn_id')
Example #13
    def test_check_operators(self):

        conn_id = "sqlite_default"

        captain_hook = BaseHook.get_hook(conn_id=conn_id)  # quite funny :D
        captain_hook.run("CREATE TABLE operator_test_table (a, b)")
        captain_hook.run("insert into operator_test_table values (1,2)")

        self.dag.create_dagrun(run_type=DagRunType.MANUAL,
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE)
        op = CheckOperator(task_id='check',
                           sql="select count(*) from operator_test_table",
                           conn_id=conn_id,
                           dag=self.dag)

        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

        op = ValueCheckOperator(
            task_id='value_check',
            pass_value=95,
            tolerance=0.1,
            conn_id=conn_id,
            sql="SELECT 100",
            dag=self.dag,
        )
        op.run(start_date=DEFAULT_DATE,
               end_date=DEFAULT_DATE,
               ignore_ti_state=True)

        captain_hook.run("drop table operator_test_table")
Example #14
def slack_success_notification(context):
    slack_webhook_token = BaseHook.get_connection("slack").password

    msg = """
          :green_circle: Task Successful. 
          *Task*: {task}  
          *Dag*: {dag} 
          *Execution Time*: {exec_date}  
          *Log Url*: {log_url} 
          """.format(
        task=context.get("task_instance").task_id,
        dag=context.get("task_instance").dag_id,
        ti=context.get("task_instance"),
        exec_date=context.get("execution_date"),
        log_url=context.get("task_instance").log_url,
    )

    success_alert = SlackWebhookOperator(
        task_id="slack_success_notification",
        http_conn_id="slack",
        webhook_token=slack_webhook_token,
        message=msg,
        username="******",
    )

    return success_alert.execute(context=context)
Example #15
def create_connection():
    try:
        db_con = BaseHook.get_connection('skyeng_db')
        # Connection objects expose no connect() method; going through the
        # hook is the standard route to an actual DB-API connection.
        return db_con.get_hook().get_conn()
    except Exception as e:
        logging.error(e)
    return None
Example #16
def _post_sendgrid_mail(mail_data: Dict,
                        conn_id: str = "sendgrid_default") -> None:
    api_key = None
    try:
        conn = BaseHook.get_connection(conn_id)
        api_key = conn.password
    except AirflowException:
        pass
    if api_key is None:
        warnings.warn(
            "Fetching Sendgrid credentials from environment variables will be deprecated in a future "
            "release. Please set credentials using a connection instead.",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        api_key = os.environ.get('SENDGRID_API_KEY')
    sendgrid_client = sendgrid.SendGridAPIClient(api_key=api_key)
    response = sendgrid_client.client.mail.send.post(request_body=mail_data)
    # 2xx status code.
    if 200 <= response.status_code < 300:
        log.info(
            'Email with subject %s is successfully sent to recipients: %s',
            mail_data['subject'],
            mail_data['personalizations'],
        )
    else:
        log.error(
            'Failed to send out email with subject %s, status code: %s',
            mail_data['subject'],
            response.status_code,
        )
Example #17
def task_fail_slack_alert(context):
    slack_webhook_token = BaseHook.get_connection(SLACK_CONN_ID).password
    if var_loader.get_git_user() != "cloud-bulldozer":
        print("Task Failed")
        return
    if context.get('task_instance').task_id != "final_status":
        print(context.get('task_instance').task_id, "Task failed")
        return

    slack_msg = """
            :red_circle: DAG Failed {mem} 
            *Task*: {task}  
            *Dag*: {dag} 
            *Execution Time*: {exec_date}  
            *Log Url*: {log_url} 
            """.format(
        task=context.get('task_instance').task_id,
        dag=context.get('task_instance').dag_id,
        mem=alert_members(context),
        ti=context.get('task_instance'),
        exec_date=context.get('execution_date'),
        log_url=get_hyperlink(context),
    )
    failed_alert = SlackWebhookOperator(task_id='slack_test',
                                        http_conn_id='slack',
                                        webhook_token=slack_webhook_token,
                                        message=slack_msg,
                                        username='******',
                                        link_names=True)
    return failed_alert.execute(context=context)
Example #18
    def get_db_hook(self):
        """
        Get the database hook for the connection.

        :return: the database hook object.
        :rtype: DbApiHook
        """
        return BaseHook.get_hook(conn_id=self.conn_id)
Example #19
def kerberos_auth(conn_id: str) -> 'hdfs.ext.kerberos.KerberosClient':
    logging.info('Getting kerberos ticket ...')
    # Fetch the connection once instead of four separate lookups.
    conn = BaseHook.get_connection(conn_id)
    login, password = conn.login, conn.password
    host, port = conn.host, conn.port

    passwd = subprocess.Popen(('echo', password), stdout=subprocess.PIPE)
    subprocess.call(('kinit', login), stdin=passwd.stdout)

    session = requests.Session()
    session.verify = False

    return KerberosClient(f'https://{host}:{port}',
                          mutual_auth="REQUIRED",
                          max_concurrency=256,
                          session=session)
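Illustrative use; the connection ID and HDFS path are hypothetical, and the returned object is a standard hdfs-library client:

client = kerberos_auth("webhdfs_kerberos")  # hypothetical connection ID
print(client.list("/data/landing"))         # hypothetical HDFS path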
Example #20
def get_airflow_gcp_credentials(conn_id: str) -> tuple:
    connection = BaseHook.get_connection(conn_id)
    conn_extra = connection.extra_dejson
    key_access_keyfile = "extra__google_cloud_platform__keyfile_dict"
    keyfile_str = conn_extra[key_access_keyfile]
    keyfile_dict = json.loads(keyfile_str)
    project_id = keyfile_dict["project_id"]
    return (project_id, keyfile_dict)
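A sketch of turning the returned keyfile dict into usable credentials, assuming the google-auth package is available and a connection named 'google_cloud_default' exists:

from google.oauth2 import service_account

project_id, keyfile_dict = get_airflow_gcp_credentials("google_cloud_default")
# Build credentials directly from the in-memory keyfile dict.
credentials = service_account.Credentials.from_service_account_info(keyfile_dict)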
Example #21
    def _get_hook(self):
        conn = BaseHook.get_connection(self.conn_id)
        hook = conn.get_hook(hook_params=self.hook_params)
        if not isinstance(hook, DbApiHook):
            raise AirflowException(
                f'The connection type is not supported by {self.__class__.__name__}. '
                f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
            )
        return hook
Example #22
def pg_params(conn_id: str = "postgres_default") -> str:
    """
    Args:
        conn_id: ID of the database connection, provided with default parameters
    Returns:
        connection string with default psql options appended
    """
    connection_uri = BaseHook.get_connection(conn_id).get_uri().split("?")[0]
    return f"{connection_uri} -X --set ON_ERROR_STOP=1"
Example #23
    def _get_hook(self) -> DbApiHook:
        self.log.debug("Get connection for %s", self.sql_conn_id)
        conn = BaseHook.get_connection(self.sql_conn_id)
        hook = conn.get_hook()
        if not callable(getattr(hook, 'get_pandas_df', None)):
            raise AirflowException(
                "This hook is not supported. The hook class must have get_pandas_df method."
            )
        return hook
Example #24
def connections_get(args):
    """Get a connection."""
    try:
        conn = BaseHook.get_connection(args.conn_id)
    except AirflowNotFoundException:
        raise SystemExit("Connection not found.")
    AirflowConsole().print_as(
        data=[conn],
        output=args.output,
        mapper=_connection_mapper,
    )
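Outside the CLI parser, a simple namespace can stand in for the parsed arguments; the values are illustrative:

from types import SimpleNamespace

# Hypothetical direct call with conn_id and output format supplied by hand.
connections_get(SimpleNamespace(conn_id="postgres_default", output="json"))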
Example #25
    def _get_hook(self) -> DbApiHook:
        self.log.debug("Get connection for %s", self.sql_conn_id)
        conn = BaseHook.get_connection(self.sql_conn_id)
        if Version(version) >= Version('2.3'):
            # "hook_params" was introduced into "get_hook()" only in Airflow 2.3.
            hook = conn.get_hook(hook_params=self.sql_hook_params)  # ignore airflow compat check
        else:
            # For supporting Airflow versions < 2.3, we backport the "get_hook()" method. This should be
            # removed when "apache-airflow-providers-slack" depends on Airflow >= 2.3.
            hook = _backported_get_hook(conn, hook_params=self.sql_hook_params)
        if not callable(getattr(hook, 'get_pandas_df', None)):
            raise AirflowException(
                "This hook is not supported. The hook class must have get_pandas_df method."
            )
        return hook
Example #26
    def connection(self) -> Iterator[SwiftService]:
        """Set up the objectstore connection.

        Yields:
            Iterator: An objectstore connection
        """
        options = None
        if self.swift_conn_id != "swift_default":
            options = {}
            connection = BaseHook.get_connection(self.swift_conn_id)
            options["os_username"] = connection.login
            options["os_password"] = connection.password
            options["os_tenant_name"] = connection.host
        yield SwiftService(options=options)
Example #27
def _download_citi_bike_data(ts_nodash, **_):
    citibike_conn = BaseHook.get_connection(conn_id="citibike")

    url = f"http://{citibike_conn.host}:{citibike_conn.port}/recent/minute/15"
    response = requests.get(url,
                            auth=HTTPBasicAuth(citibike_conn.login,
                                               citibike_conn.password))
    data = response.json()

    s3_hook = S3Hook(aws_conn_id="s3")
    s3_hook.load_string(
        string_data=json.dumps(data),
        key=f"raw/citibike/{ts_nodash}.json",
        bucket_name="datalake",
    )
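A sketch of wiring this callable into a task; in Airflow 2 the context (including ts_nodash) is passed to the callable automatically:

from airflow.operators.python import PythonOperator

download_citi_bike = PythonOperator(
    task_id="download_citi_bike_data",
    python_callable=_download_citi_bike_data,
)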
Example #28
    def _hook(self):
        """Get DB Hook based on connection type"""
        self.log.debug("Get connection for %s", self.conn_id)
        conn = BaseHook.get_connection(self.conn_id)

        hook = conn.get_hook(hook_params=self.hook_params)
        if not isinstance(hook, DbApiHook):
            raise AirflowException(
                f'The connection type is not supported by {self.__class__.__name__}. '
                f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
            )

        if self.database:
            hook.schema = self.database

        return hook
Example #29
    def _get_hook(self):
        self.log.debug("Get connection for %s", self.conn_id)
        conn = BaseHook.get_connection(self.conn_id)

        if conn.conn_type not in ALLOWED_CONN_TYPE:
            raise AirflowException(
                "The connection type is not supported by BranchSQLOperator. "
                "Supported connection types: {}".format(list(ALLOWED_CONN_TYPE)))

        if not self._hook:
            self._hook = conn.get_hook()
            if self.database:
                self._hook.schema = self.database

        return self._hook
Example #30
    def execute(self, context: 'Context'):
        self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
        hook = CloudSQLDatabaseHook(
            gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
            gcp_conn_id=self.gcp_conn_id,
            default_gcp_project_id=self.gcp_connection.extra_dejson.get(
                'extra__google_cloud_platform__project'),
        )
        hook.validate_ssl_certs()
        connection = hook.create_connection()
        hook.validate_socket_path_length()
        database_hook = hook.get_database_hook(connection=connection)
        try:
            self._execute_query(hook, database_hook)
        finally:
            hook.cleanup_database_hook()