Example #1
    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        self._create_clusters()
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='redshift')
        client_from_hook = hook.get_conn()

        clusters = client_from_hook.describe_clusters()['Clusters']
        self.assertEqual(len(clusters), 2)
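The `_create_clusters` helper is defined elsewhere in the test class. A minimal sketch of what it does, assuming the test runs under moto's fake Redshift backend (the node type and credentials below are placeholders):

    # Sketch of the helper above; assumes the test is decorated with moto's
    # @mock_redshift so no real AWS calls are made.
    def _create_clusters(self):
        import boto3

        client = boto3.client('redshift', region_name='us-east-1')
        for ident in ('tests-1', 'tests-2'):
            # Two clusters, so len(clusters) == 2 in the assertion above.
            client.create_cluster(
                ClusterIdentifier=ident,
                NodeType='dc1.large',
                MasterUsername='admin',
                MasterUserPassword='mock_password',
            )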
Example #2
    def get_iam_token(self, conn):
        """
        Uses AWSHook to retrieve a temporary password to connect to MySQL
        Port is required. If none is provided, default 3306 is used
        """
        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

        aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
        aws_hook = AwsBaseHook(aws_conn_id, client_type='rds')
        if conn.port is None:
            port = 3306
        else:
            port = conn.port
        client = aws_hook.get_conn()
        token = client.generate_db_auth_token(conn.host, port, conn.login)
        return token, port
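This method is the hook-side half of RDS IAM database authentication: Airflow's MySqlHook calls it when the MySQL connection enables IAM auth, then uses the returned token in place of the stored password. A minimal connection setup, assuming the standard `iam` extra flag (RDS IAM auth also requires SSL on the MySQL side):

    # Connection "extra" for a MySQL connection that authenticates via IAM;
    # the token returned by get_iam_token() replaces the stored password.
    {"iam": true, "aws_conn_id": "aws_default"}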
Example #3
    def test_run_example_gcp_vision_autogenerated_id_dag(self):
        mock_connection = Connection(
            conn_type="aws",
            extra=json.dumps({
                "role_arn": ROLE_ARN,
                "assume_role_method": "assume_role_with_web_identity",
                "assume_role_with_web_identity_federation": "google",
                "assume_role_with_web_identity_federation_audience": AUDIENCE,
            }),
        )

        with mock.patch.dict(
                'os.environ',
                AIRFLOW_CONN_AWS_DEFAULT=mock_connection.get_uri()):
            # aws_conn_id defaults to 'aws_default', which resolves to the
            # AIRFLOW_CONN_AWS_DEFAULT environment variable patched above.
            hook = AwsBaseHook(client_type='s3')

            client = hook.get_conn()
            response = client.list_buckets()
            assert 'Buckets' in response
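With this configuration the hook exchanges a Google OpenID Connect token for temporary AWS credentials. Roughly the equivalent raw STS call, as a sketch (the session name and `oidc_token` variable are illustrative):

    # What the hook does internally, approximately: trade a web-identity
    # token for short-lived credentials. oidc_token is a placeholder.
    import boto3

    sts = boto3.client('sts')
    credentials = sts.assume_role_with_web_identity(
        RoleArn=ROLE_ARN,
        RoleSessionName='airflow-federated-session',
        WebIdentityToken=oidc_token,
    )['Credentials']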
Example #4
    def execute(self, context):
        hook = AwsBaseHook(self._aws_conn_id,
                           client_type="glue",
                           region_name=self._region_name)
        glue_client = hook.get_conn()

        self.log.info("Triggering crawler")
        response = glue_client.start_crawler(Name=self._crawler_name)

        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise RuntimeError(
                "An error occurred while triggering the crawler: %r" %
                response)

        self.log.info("Waiting for crawler to finish")
        while True:
            time.sleep(1)

            crawler = glue_client.get_crawler(Name=self._crawler_name)
            crawler_state = crawler["Crawler"]["State"]

            if crawler_state == "READY":
                self.log.info("Crawler finished running")
                break
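The loop above polls forever and reports success as soon as the crawler returns to READY, even if the crawl itself failed. A hardened variant of the wait loop, assuming the documented Glue GetCrawler response fields (State, LastCrawl.Status, LastCrawl.ErrorMessage):

    # Sketch: bound the wait and surface failed crawls. Drop-in replacement
    # for the while-loop in execute(); the timeout value is illustrative.
    deadline = time.monotonic() + 30 * 60
    while time.monotonic() < deadline:
        time.sleep(10)
        crawler = glue_client.get_crawler(Name=self._crawler_name)["Crawler"]
        if crawler["State"] == "READY":
            last_crawl = crawler.get("LastCrawl", {})
            if last_crawl.get("Status") == "FAILED":
                raise RuntimeError(
                    "Crawler failed: %s" % last_crawl.get("ErrorMessage"))
            self.log.info("Crawler finished running")
            break
    else:
        raise TimeoutError("Crawler did not finish within 30 minutes")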
Example #5
# Module-level setup so the snippet is self-contained.
import logging

from airflow.models import Variable
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

LOGGING = logging.getLogger(__name__)


def airflow_notify_sns(context, **kwargs):
    """
    Publish an Airflow error notification to an SNS topic.

    Parameters:
        context (dict): Airflow task execution context

    Returns:
        boto3 sns_client.publish() response, or None on failure
    """
    sns_hook = AwsBaseHook(client_type="sns", aws_conn_id='aws_default')
    sns_topic_arn = Variable.get('airflow_notify_sns_arn', None)

    # Make variable required
    if sns_topic_arn is None:
        LOGGING.error("Variable [airflow_notify_sns_arn] not found in Airflow")
        return None

    # Message attributes; get_message_text() builds the body from the task
    # context and is defined elsewhere in this module.
    subject = "Airflow task execution failed"
    message = get_message_text(context)

    # Sending message to topic
    LOGGING.info(f"Error message to send: {message}")
    LOGGING.info(f"Sending error message to SNS Topic ARN [{sns_topic_arn}]")
    try:
        response = sns_hook.get_conn().publish(
            TopicArn=sns_topic_arn,
            Subject=subject,
            Message=message,
        )
        LOGGING.info("Message successfully sent to SNS Topic")
        return response
    except Exception as ex:
        LOGGING.error(f"Error sending message to SNS: [{ex}]")
        return None
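A typical way to wire this up, assuming it is registered as a task failure callback in a DAG's default_args (Airflow passes the task context as the first argument):

    # Usage sketch: every failed task in the DAG publishes to the SNS topic.
    default_args = {
        'owner': 'airflow',
        'on_failure_callback': airflow_notify_sns,
    }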