class TestAWSAthenaOperator(unittest.TestCase):

    def setUp(self):
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }
        self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once',
                       default_args=args,
                       schedule_interval='@once')
        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator',
                                        query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE',
                                        output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=0,
                                        max_tries=3,
                                        dag=self.dag)

    def test_init(self):
        self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
        self.assertEqual(self.athena.query, MOCK_DATA['query'])
        self.assertEqual(self.athena.database, MOCK_DATA['database'])
        self.assertEqual(self.athena.aws_conn_id, 'aws_default')
        self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token'])
        self.assertEqual(self.athena.sleep_time, 0)
        self.assertEqual(self.athena.hook.sleep_time, 0)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=(None, None,))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'get_state_change_reason')
    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "FAILED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_query_status,
                                    mock_get_state_change_reason):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 2)
        self.assertEqual(mock_get_state_change_reason.call_count, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "CANCELLED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "RUNNING",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_xcom_push_and_pull(self, mock_conn, mock_run_query, mock_check_query_status):
        ti = TaskInstance(task=self.athena, execution_date=timezone.utcnow())
        ti.run()
        self.assertEqual(ti.xcom_pull(task_ids='test_aws_athena_operator'), ATHENA_QUERY_ID)
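
# The tests above reference module-level fixtures that this excerpt does not
# show (DEFAULT_DATE, TEST_DAG_ID, ATHENA_QUERY_ID, MOCK_DATA, query_context,
# result_configuration). A minimal sketch of what would need to sit at the top
# of the test module, assuming the mock values mirror the arguments passed in
# setUp; the ATHENA_QUERY_ID value and import paths are assumptions (import
# paths vary across Airflow versions):
import unittest
from unittest import mock

from airflow import DAG
from airflow.models import TaskInstance
from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
from airflow.utils import timezone

TEST_DAG_ID = 'unit_tests'
DEFAULT_DATE = timezone.datetime(2018, 1, 1)
ATHENA_QUERY_ID = '00000000-0000-0000-0000-000000000000'  # placeholder query id

MOCK_DATA = {
    'task_id': 'test_aws_athena_operator',
    'query': 'SELECT * FROM TEST_TABLE',
    'database': 'TEST_DATABASE',
    'outputLocation': 's3://test_s3_bucket/',
    'client_request_token': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
    'workgroup': 'primary',
}

query_context = {'Database': MOCK_DATA['database']}
result_configuration = {'OutputLocation': MOCK_DATA['outputLocation']}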
    schedule_interval=None,
    start_date=datetime(2021, 1, 1),
    dagrun_timeout=timedelta(minutes=60),
    tags=['example'],
    catchup=False,
) as dag:

    # [START howto_athena_operator_and_sensor]

    # Using a task-decorated function to create a CSV file in S3
    add_sample_data_to_s3 = add_sample_data_to_s3()

    create_table = AWSAthenaOperator(
        task_id='setup__create_table',
        query=QUERY_CREATE_TABLE,
        database=ATHENA_DATABASE,
        output_location=f's3://{S3_BUCKET}/{S3_KEY}',
        sleep_time=30,
        max_tries=None,
        aws_conn_id=AWS_CONN_ID,
    )

    read_table = AWSAthenaOperator(
        task_id='query__read_table',
        query=QUERY_READ_TABLE,
        database=ATHENA_DATABASE,
        output_location=f's3://{S3_BUCKET}/{S3_KEY}',
        sleep_time=30,
        max_tries=None,
        aws_conn_id=AWS_CONN_ID,
    )
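
    # The [START howto_athena_operator_and_sensor] marker above also covers an
    # AthenaSensor, which this excerpt cuts off. A sketch of how such a sensor
    # is typically wired to the read_table task; the task_id and wiring here
    # are illustrative assumptions, not from the original:
    from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

    await_query = AthenaSensor(
        task_id='await_query',
        # read_table pushes its query execution id to XCom; the sensor polls
        # that query until it reaches a terminal state.
        query_execution_id=read_table.output,
    )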
"api_conn_id": "movielens", "s3_conn_id": "my_aws_conn", "s3_bucket": os.environ["RATINGS_BUCKET"], }, ) trigger_crawler = GlueTriggerCrawlerOperator( aws_conn_id="my_aws_conn", task_id="trigger_crawler", crawler_name=os.environ["CRAWLER_NAME"], ) rank_movies = AWSAthenaOperator( task_id="rank_movies", aws_conn_id="my_aws_conn", database="airflow", query=""" SELECT movieid, AVG(rating) as avg_rating, COUNT(*) as num_ratings FROM ( SELECT movieid, rating, CAST(from_unixtime(timestamp) AS DATE) AS date FROM ratings ) WHERE date <= DATE('{{ ds }}') GROUP BY movieid ORDER BY avg_rating DESC """, output_location=f"s3://{os.environ['RANKINGS_BUCKET']}/{{{{ds}}}}", ) fetch_ratings >> trigger_crawler >> rank_movies
def get_s3_to_athena_dag(parent_dag_name, task_id, s3_key, d_sqls, *args, **kwargs):
    dag = DAG(f"{parent_dag_name}.{task_id}", **kwargs)

    s_s3 = f"s3://{Variable.get('S3_DATA_BUCKET_NAME')}/{s3_key}/"
    s_output_s3 = f"s3://{Variable.get('S3_CODES_BUCKET_NAME')}/"

    create_db = AWSAthenaOperator(
        task_id='Create_database',
        query=d_sqls['createdb'],
        output_location=s_output_s3,
        database='processed',
        aws_conn_id='aws_credentials',
        dag=dag
    )

    # drop_table = AWSAthenaOperator(
    #     task_id='Drop_table',
    #     query=d_sqls['drop'],
    #     output_location=s_output_s3,
    #     database='processed',
    #     aws_conn_id='aws_credentials',
    #     dag=dag
    # )

    create_table = AWSAthenaOperator(
        task_id='Create_table',
        query=d_sqls['create'].format(s_s3),
        output_location=s_output_s3,
        database='processed',
        aws_conn_id='aws_credentials',
        dag=dag
    )

    # bulk_insert_table = AWSAthenaOperator(
    #     task_id='Insert_data_on_table',
    #     query=d_sqls['load'],
    #     output_location=s_output_s3,
    #     database='processed',
    #     aws_conn_id='aws_credentials',
    #     dag=dag
    # )

    partition_insert_table = AthenaPartitionInsert(
        task_id='Insert_data_into_table',
        query=d_sqls['load2'],
        output_location=s_output_s3,
        database='processed',
        aws_conn_id='aws_credentials',
        dag=dag
    )

    s_sql = ("SELECT COUNT(*) "
             f" FROM processed.{d_sqls['subkey']}"
             " WHERE intdate > {}0001"
             " AND intdate < {}2000")

    check_data_inserted = AthenaDataQuality(
        task_id='Fetch_data_from_table',
        query=s_sql,
        output_location=s_output_s3,
        database='processed',
        aws_conn_id='aws_credentials',
        dag=dag
    )

    # create_db >> drop_table >> create_table
    # create_table >> bulk_insert_table >> check_data_inserted
    create_db >> create_table >> partition_insert_table
    partition_insert_table >> check_data_inserted

    return dag
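
# Hypothetical sketch: AthenaPartitionInsert and AthenaDataQuality are custom
# operators not shown in this excerpt. One plausible shape for AthenaDataQuality,
# assuming it subclasses AWSAthenaOperator, fills the two {} placeholders left
# in s_sql with the run's year-month, and fails the task when the count is zero.
# The hook calls and result layout follow boto3's get_query_results response;
# base-class behavior varies across Airflow versions:
class AthenaDataQuality(AWSAthenaOperator):
    """Runs a COUNT(*) query and fails the task if no rows are in range."""

    def execute(self, context):
        yyyymm = context['ds_nodash'][:6]  # e.g. '202101'
        self.query = self.query.format(yyyymm, yyyymm)
        # The base operator runs the query and returns its execution id.
        query_execution_id = super().execute(context)
        results = self.hook.get_query_results(query_execution_id)
        # Row 0 is the header; the first data cell holds the COUNT(*) value.
        count = int(results['ResultSet']['Rows'][1]['Data'][0]['VarCharValue'])
        if count == 0:
            raise ValueError('Data quality check failed: no rows in date range')
        return count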