Example #1
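This excerpt comes from a test module and omits its module-level imports and fixtures. A minimal import block it appears to rely on might look like the sketch below; the provider paths are assumptions (Airflow 1.10 used `airflow.contrib.*` instead), and the constants `TEST_DAG_ID`, `DEFAULT_DATE`, `MOCK_DATA`, `ATHENA_QUERY_ID`, `query_context`, and `result_configuration` are defined elsewhere in the original file.

# Imports assumed by this excerpt (paths are for the Airflow 2.x Amazon
# provider package; older Airflow versions used airflow.contrib.* instead):
import unittest
from unittest import mock

from airflow import DAG
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
from airflow.utils import timezone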
class TestAWSAthenaOperator(unittest.TestCase):

    def setUp(self):
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }

        self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once',
                       default_args=args,
                       schedule_interval='@once')
        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE', output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=0, max_tries=3, dag=self.dag)

    def test_init(self):
        self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
        self.assertEqual(self.athena.query, MOCK_DATA['query'])
        self.assertEqual(self.athena.database, MOCK_DATA['database'])
        self.assertEqual(self.athena.aws_conn_id, 'aws_default')
        self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token'])
        self.assertEqual(self.athena.sleep_time, 0)

        self.assertEqual(self.athena.hook.sleep_time, 0)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=(None, None,))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'get_state_change_reason')
    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "FAILED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_query_status,
                                    mock_get_state_change_reason):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 2)
        self.assertEqual(mock_get_state_change_reason.call_count, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "CANCELLED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "RUNNING",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_xcom_push_and_pull(self, mock_conn, mock_run_query, mock_check_query_status):
        ti = TaskInstance(task=self.athena, execution_date=timezone.utcnow())
        ti.run()

        self.assertEqual(ti.xcom_pull(task_ids='test_aws_athena_operator'),
                         ATHENA_QUERY_ID)
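Because every hook method that touches AWS is patched with `mock.patch.object`, these tests run entirely offline; they can presumably be executed with the standard runner, e.g. `python -m unittest`.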
Example #2
# Opening of the DAG context reconstructed for readability; the original
# excerpt begins mid-call (assumes: from datetime import datetime, timedelta
# and from airflow import DAG).
with DAG(
        dag_id='example_athena',  # dag_id assumed, not part of the excerpt
        schedule_interval=None,
        start_date=datetime(2021, 1, 1),
        dagrun_timeout=timedelta(minutes=60),
        tags=['example'],
        catchup=False,
) as dag:
    # [START howto_athena_operator_and_sensor]

    # Using a task-decorated function to create a CSV file in S3
    add_sample_data_to_s3 = add_sample_data_to_s3()

    create_table = AWSAthenaOperator(
        task_id='setup__create_table',
        query=QUERY_CREATE_TABLE,
        database=ATHENA_DATABASE,
        output_location=f's3://{S3_BUCKET}/{S3_KEY}',
        sleep_time=30,
        max_tries=None,
        aws_conn_id=AWS_CONN_ID,
    )

    read_table = AWSAthenaOperator(
        task_id='query__read_table',
        query=QUERY_READ_TABLE,
        database=ATHENA_DATABASE,
        output_location=f's3://{S3_BUCKET}/{S3_KEY}',
        sleep_time=30,
        max_tries=None,
        aws_conn_id=AWS_CONN_ID,
    )
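The `[START howto_athena_operator_and_sensor]` marker suggests the full example also waits for the query with a sensor. A hedged sketch of that step, assuming the Amazon provider's `AthenaSensor` (the operator pushes its query execution id to XCom, exposed as `read_table.output` in Airflow 2.x):

    from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

    # Poll Athena until the read_table query reaches a terminal state.
    await_query = AthenaSensor(
        task_id='await_query',
        query_execution_id=read_table.output,
        max_retries=None,
        sleep_time=10,
    )

    read_table >> await_query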
Example #3
            "api_conn_id": "movielens",
            "s3_conn_id": "my_aws_conn",
            "s3_bucket": os.environ["RATINGS_BUCKET"],
        },
    )

    trigger_crawler = GlueTriggerCrawlerOperator(
        aws_conn_id="my_aws_conn",
        task_id="trigger_crawler",
        crawler_name=os.environ["CRAWLER_NAME"],
    )

    rank_movies = AWSAthenaOperator(
        task_id="rank_movies",
        aws_conn_id="my_aws_conn",
        database="airflow",
        query="""
            SELECT movieid, AVG(rating) as avg_rating, COUNT(*) as num_ratings
            FROM (
                SELECT movieid, rating, CAST(from_unixtime(timestamp) AS DATE) AS date
                FROM ratings
            )
            WHERE date <= DATE('{{ ds }}')
            GROUP BY movieid
            ORDER BY avg_rating DESC
        """,
        output_location=f"s3://{os.environ['RANKINGS_BUCKET']}/{{{{ds}}}}",
    )

    fetch_ratings >> trigger_crawler >> rank_movies
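`query` and `output_location` are template fields on `AWSAthenaOperator`, so the `{{ ds }}` references above are presumably rendered to the run's logical date before the query is submitted.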
Example #4
def get_s3_to_athena_dag(parent_dag_name, task_id, s3_key, d_sqls, *args,
                         **kwargs):
    """Build a sub-DAG that creates an Athena database and table over data
    in S3, inserts partitioned data, and runs a row-count quality check."""

    dag = DAG(f"{parent_dag_name}.{task_id}", **kwargs)

    # Bucket names are read from Airflow Variables at parse time.
    s_s3 = f"s3://{Variable.get('S3_DATA_BUCKET_NAME')}/{s3_key}/"
    s_output_s3 = f"s3://{Variable.get('S3_CODES_BUCKET_NAME')}/"

    create_db = AWSAthenaOperator(task_id='Create_database',
                                  query=d_sqls['createdb'],
                                  output_location=s_output_s3,
                                  database='processed',
                                  aws_conn_id='aws_credentials',
                                  dag=dag)

    # drop_table = AWSAthenaOperator(
    #     task_id='Drop_table',
    #     query=d_sqls['drop'],
    #     output_location=s_output_s3,
    #     database='processed',
    #     aws_conn_id='aws_credentials',
    #     dag=dag
    # )

    create_table = AWSAthenaOperator(task_id='Create_table',
                                     query=d_sqls['create'].format(s_s3),
                                     output_location=s_output_s3,
                                     database='processed',
                                     aws_conn_id='aws_credentials',
                                     dag=dag)

    # bulk_insert_table = AWSAthenaOperator(
    #     task_id='Insert_data_on_table',
    #     query=d_sqls['load'],
    #     output_location=s_output_s3,
    #     database='processed',
    #     aws_conn_id='aws_credentials',
    #     dag=dag
    # )
    partition_insert_table = AthenaPartitionInsert(
        task_id='Insert_data_into_table',
        query=d_sqls['load2'],
        output_location=s_output_s3,
        database='processed',
        aws_conn_id='aws_credentials',
        dag=dag)

    s_sql = ("SELECT COUNT(*) "
             f"  FROM processed.{d_sqls['subkey']}"
             "  WHERE intdate > {}0001"
             "  AND intdate < {}2000")

    check_data_inserted = AthenaDataQuality(task_id='Fetch_data_from_table',
                                            query=s_sql,
                                            output_location=s_output_s3,
                                            database='processed',
                                            aws_conn_id='aws_credentials',
                                            dag=dag)

    # create_db >> drop_table >> create_table
    # create_table >> bulk_insert_table >> check_data_inserted
    create_db >> create_table >> partition_insert_table
    partition_insert_table >> check_data_inserted

    return dag
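Since the factory returns a DAG whose id is `parent_dag_name.task_id`, it is presumably meant to be wrapped in a `SubDagOperator`. A hypothetical usage sketch follows; the parent DAG name, task id, S3 key, and the `d_sqls` stub are illustrative, not from the original:

from datetime import datetime

from airflow import DAG
from airflow.operators.subdag import SubDagOperator  # airflow.operators.subdag_operator in 1.10

# Stub SQL mapping; the real dict holds the CREATE/LOAD statements.
d_sqls = {'createdb': '...', 'create': '...', 'load2': '...', 'subkey': 'ratings'}

with DAG('etl_parent', start_date=datetime(2021, 1, 1),
         schedule_interval='@daily') as parent_dag:
    s3_to_athena = SubDagOperator(
        task_id='s3_to_athena',  # must match the task_id passed to the factory
        subdag=get_s3_to_athena_dag('etl_parent', 's3_to_athena',
                                    s3_key='ratings',  # hypothetical S3 prefix
                                    d_sqls=d_sqls,
                                    start_date=datetime(2021, 1, 1),
                                    schedule_interval='@daily'),
    )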