Example No. 1
    def setUp(self):
        configuration.load_test_config()

        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE', output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=1, max_tries=3)
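
In this variant, `sleep_time` is the pause between status polls and `max_tries` caps how many times the operator checks the query before giving up; both are distinct from Airflow's task-level `retries`. A rough sketch of the polling pattern (simplified names, not the hook's actual source):

import time

def poll_until_terminal(hook, query_execution_id, sleep_time=1, max_tries=3):
    """Sketch: poll the query status until it leaves the in-progress states."""
    status = None
    for _ in range(max_tries):
        status = hook.check_query_status(query_execution_id)
        if status not in ("QUEUED", "RUNNING"):
            break  # terminal, e.g. SUCCESS/FAILED/CANCELLED as mocked in the tests below
        time.sleep(sleep_time)  # wait sleep_time seconds between polls
    return status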
Example No. 2
    def setUp(self):
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }

        self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once',
                       default_args=args,
                       schedule_interval='@once')
        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE', output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=1, max_tries=3, dag=self.dag)
Example No. 3
    def setUp(self):
        configuration.load_test_config()

        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE', output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=1)
Example No. 4
class TestAWSAthenaOperator(unittest.TestCase):

    def setUp(self):
        configuration.load_test_config()

        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE', output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=1)

    def test_init(self):
        self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
        self.assertEqual(self.athena.query, MOCK_DATA['query'])
        self.assertEqual(self.athena.database, MOCK_DATA['database'])
        self.assertEqual(self.athena.aws_conn_id, 'aws_default')
        self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token'])
        self.assertEqual(self.athena.sleep_time, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "SUCCESS",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "FAILED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 2)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "CANCELLED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 3)
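
These tests reference module-level fixtures (`MOCK_DATA`, `query_context`, `result_configuration`) that the snippet omits. A plausible reconstruction, mirroring the shapes the Athena API expects (exact values are assumptions based on the setUp above):

MOCK_DATA = {
    'task_id': 'test_aws_athena_operator',
    'query': 'SELECT * FROM TEST_TABLE',
    'database': 'TEST_DATABASE',
    'outputLocation': 's3://test_s3_bucket/',
    'client_request_token': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
}
# The hook receives the database and output location in the Athena API's shapes:
query_context = {'Database': MOCK_DATA['database']}
result_configuration = {'OutputLocation': MOCK_DATA['outputLocation']}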
Example No. 5
class TestAWSAthenaOperator(unittest.TestCase):
    def setUp(self):
        configuration.load_test_config()

        self.athena = AWSAthenaOperator(
            task_id='test_aws_athena_operator',
            query='SELECT * FROM TEST_TABLE',
            database='TEST_DATABASE',
            output_location='s3://test_s3_bucket/',
            client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
            sleep_time=1)

    def test_init(self):
        self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
        self.assertEqual(self.athena.query, MOCK_DATA['query'])
        self.assertEqual(self.athena.database, MOCK_DATA['database'])
        self.assertEqual(self.athena.aws_conn_id, 'aws_default')
        self.assertEqual(self.athena.client_request_token,
                         MOCK_DATA['client_request_token'])
        self.assertEqual(self.athena.sleep_time, 1)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=("SUCCESS", ))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_small_success_query(self, mock_conn, mock_run_query,
                                          mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 1)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "RUNNING",
                           "SUCCESS",
                       ))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_big_success_query(self, mock_conn, mock_run_query,
                                        mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "FAILED",
                       ))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failure_query(self, mock_conn, mock_run_query,
                                    mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 2)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "RUNNING",
                           "CANCELLED",
                       ))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_cancelled_query(self, mock_conn, mock_run_query,
                                      mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 3)
Example No. 6
# from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'Airflow',
    'depends_on_past': False,
    'start_date': datetime(2019, 12, 5),
    'email': ['*****@*****.**'],
    'email_on_failure': True,
    'email_on_retry': True,
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    # 'queue': 'bash_queue',
    # 'pool': 'backfill',
    # 'priority_weight': 10,
    # 'end_date': datetime(2016, 1, 1),
}

with DAG(dag_id='simple_athena_query2',
         schedule_interval="@monthly",
         start_date=datetime(2019, 12, 5),
         default_args=default_args) as dag:

    run_query = AWSAthenaOperator(
        task_id='run_query',
        query="""
            select cast(day / 100 as varchar(6)), count(1)
            from sampledb.minuts_info
            where cast(replace(cast(date_trunc('month', DATE('{{ ds }}')) as varchar(7)), '-', '') as varchar(6))
                  = cast(day / 100 as varchar(6))
            group by cast(day / 100 as varchar(6))
        """,
        output_location='s3://matchestest/airflow_athena/',
        database='sampledb')

##task_one >> task_two >> [task_two_1, task_two_2, task_two_3] >> end
Example No. 7
# from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
from datetime import datetime, timedelta

default_args = {
    'owner': 'Airflow',
    'depends_on_past': False,
    'start_date': datetime(2019, 12, 5),
    'email': ['*****@*****.**'],
    'email_on_failure': True,
    'email_on_retry': True,
    'retries': 2,
    'retry_delay': timedelta(minutes=5),
    # 'queue': 'bash_queue',
    # 'pool': 'backfill',
    # 'priority_weight': 10,
    # 'end_date': datetime(2016, 1, 1),
}

with DAG(dag_id='simple_athena_query3',
         schedule_interval="@monthly",
         start_date=datetime(2019, 12, 5),
         template_searchpath=['/sqls/'],
         default_args=default_args) as dag:

    run_query = AWSAthenaOperator(
        task_id='run_query',
        query="group_by_month.sql",
        output_location='s3://matchestest/airflow_athena/',
        database='sampledb')

##task_one >> task_two >> [task_two_1, task_two_2, task_two_3] >> end
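
Unlike Example No. 6, the query here is a `.sql` filename: because `query` is a templated field and `.sql` is among the operator's template extensions, Airflow loads the file from `template_searchpath` and renders it with Jinja before execution. The file's contents aren't shown; assuming it holds the same statement as Example No. 6, the inline equivalent would be:

run_query_inline = AWSAthenaOperator(
    task_id='run_query_inline',
    # Hypothetical body standing in for /sqls/group_by_month.sql:
    query="""
        select cast(day / 100 as varchar(6)), count(1)
        from sampledb.minuts_info
        where cast(replace(cast(date_trunc('month', DATE('{{ ds }}')) as varchar(7)), '-', '') as varchar(6))
              = cast(day / 100 as varchar(6))
        group by cast(day / 100 as varchar(6))
    """,
    output_location='s3://matchestest/airflow_athena/',
    database='sampledb')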
Example No. 8
athena_drop_output_table_query = "DROP TABLE default.handson_output_parquet"

athena_create_input_table_query = f"""
CREATE EXTERNAL TABLE IF NOT EXISTS default.handson_input_csv(
  deviceid string,
  uuid bigint,
  appid bigint,
  country string,
  year bigint,
  month bigint,
  day bigint,
  hour bigint)
ROW FORMAT DELIMITED
  FIELDS TERMINATED BY ','
STORED AS INPUTFORMAT
  'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT
  'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION
  's3://{s3_bucket_name}/{input_path}/'
TBLPROPERTIES (
  'classification'='csv',
  'delimiter'=',',
  'skip.header.line.count'='1'
)"""

athena_ctas_new_table_query = f"""
CREATE TABLE "default"."handson_output_parquet"
WITH (format = 'PARQUET',
      external_location = 's3://{s3_bucket_name}/{output_path}/',
      parquet_compression = 'SNAPPY')
AS SELECT *
FROM "default"."handson_input_csv"
WHERE deviceid = 'iphone' OR deviceid = 'android'"""

def s3_bucket_cleaning_job():
    # Error-injection hook for testing; uncomment to force a failure:
    # raise Exception('test error')
    s3 = boto3.resource('s3')
    bucket = s3.Bucket(s3_bucket_name)
    bucket.objects.filter(Prefix=output_path).delete()

with DAG(
    dag_id="etl_athena_job",
    description="etl athena DAG",
    default_args=args,
    schedule_interval="*/60 * * * *",
    catchup=False,
    tags=['handson']
) as dag:
    t1 = AWSAthenaOperator(task_id="athena_drop_output_table",
                           query=athena_drop_output_table_query,
                           database="default",
                           output_location=athena_results)
    t2 = PythonOperator(task_id="s3_bucket_cleaning",
                        python_callable=s3_bucket_cleaning_job)
    t3 = AWSAthenaOperator(task_id="athena_create_input_table",
                           query=athena_create_input_table_query,
                           database="default",
                           output_location=athena_results)
    t4 = AWSAthenaOperator(task_id="athena_ctas_new_table",
                           query=athena_ctas_new_table_query,
                           database="default",
                           output_location=athena_results)

    t1 >> t3 >> t4
    t2 >> t3 >> t4
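
This snippet assumes module-level setup that isn't shown: `args`, `s3_bucket_name`, `input_path`, `output_path`, and `athena_results`. A minimal sketch of what those definitions could look like (all values are placeholders):

import boto3  # used by s3_bucket_cleaning_job above
from datetime import datetime

args = {
    'owner': 'airflow',
    'start_date': datetime(2021, 1, 1),
}
s3_bucket_name = 'my-handson-bucket'   # placeholder bucket name
input_path = 'handson/input'           # prefix holding the raw CSV files
output_path = 'handson/output'         # prefix for the Parquet CTAS output
athena_results = f's3://{s3_bucket_name}/athena-results/'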
Example No. 9
drop_athena_export_table = PythonOperator(
    task_id='drop_athena_export_table',
    provide_context=True,
    python_callable=drop_athena_export_table,
    dag=dag
    )

clear_export_folder = PythonOperator(
    task_id='clear_export_folder',
    provide_context=True,
    python_callable=clear_export_folder,
    dag=dag
    )

export_athena_scifi_table = AWSAthenaOperator(
    task_id="export_athena_scifi_table",
    # query=export_athena_scifi_table_query,
    query=export_athena_scifi_table_query2,
    workgroup="devday-demo",
    database=athena_db,
    sleep_time=60,
    output_location='s3://' + s3_dlake + "/" + athena_output + 'export_athena_scifi_table',
    dag=dag
    )


export_scifi_tofile = PythonOperator(
    task_id='export_scifi_tofile',
    provide_context=True,
    python_callable=export_scifi_tofile,
    dag=dag
    )

check_athena_export_table.set_upstream(disp_variables)
drop_athena_export_table.set_upstream(check_athena_export_table)
Example No. 10
    # 'priority_weight': 10,
    # 'end_date': datetime(2016, 1, 1),
}

with DAG(
        dag_id='bet369_firstpart_goals_month',
        schedule_interval="@monthly",
        start_date=datetime(2019, 12, 5),
        # template_searchpath=['./sqls/'],
        default_args=default_args) as dag:

    # Need to include the partition load for minuts_info
    refresh_minuts = AWSAthenaOperator(
        task_id='refresh_parts_minuts_info',
        query="MSCK REPAIR TABLE sampledb.minuts_info;",
        output_location='s3://matchestest/airflow_athena/logs',
        database='sampledb')

    drop_tmp_minuts_info0 = AWSAthenaOperator(
        task_id='00_drop_tmp_minuts_info',
        query="DROP TABLE IF EXISTS sampledb.TMP_MONTH_NODUPS;",
        output_location='s3://matchestest/airflow_athena/logs',
        database='sampledb')

    tmp_minuts_info0 = AWSAthenaOperator(
        task_id='00_tmp_minuts_info',
        query="/sqls/00_month_tmp_minuts_info.sql",
        output_location='s3://matchestest/airflow_athena/logs',
        database='sampledb')
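
Note that `template_searchpath` is commented out in the `with DAG(...)` block above, while `00_tmp_minuts_info` passes a path-like query; for the `.sql` file to be located and rendered as a template, the search path has to be supplied. A sketch of the corrected setup (assuming the file lives under /sqls/, as in Example No. 7):

with DAG(dag_id='bet369_firstpart_goals_month',
         schedule_interval='@monthly',
         start_date=datetime(2019, 12, 5),
         template_searchpath=['/sqls/'],  # lets .sql queries resolve by filename
         default_args=default_args) as dag:

    tmp_minuts_info0 = AWSAthenaOperator(
        task_id='00_tmp_minuts_info',
        query='00_month_tmp_minuts_info.sql',
        output_location='s3://matchestest/airflow_athena/logs',
        database='sampledb')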
Example No. 11
    timeout=20,
    poke_interval=5,
    soft_fail=True,
    # bucket_key=f"{GlobalArgs.S3_RAW_DATA_PREFIX}/movie_ratings_*",
    bucket_key=
    f"{GlobalArgs.S3_RAW_DATA_PREFIX}/dt={datetime.datetime.now().strftime('%Y_%m_%d')}/{GlobalArgs.S3_KEY_NAME}",
    bucket_name=GlobalArgs.S3_BKT_NAME,
    wildcard_match=True,
    s3_conn_id='aws_default',
    dag=redshift_ingestor_dag)

# Task to create Athena Database
create_athena_database_movie_ratings = AWSAthenaOperator(
    task_id="create_athena_database_movie_ratings",
    query=CREATE_ATHENA_DATABASE_MOVIES_QUERY,
    database=GlobalArgs.ATHENA_DB,
    output_location=
    f"s3://{GlobalArgs.S3_BKT_NAME}/{GlobalArgs.ATHENA_RESULTS}/create_athena_database_movie_ratings"
)

# Task to create Athena Table
create_athena_table_movie_ratings = AWSAthenaOperator(
    task_id="create_athena_table_movie_ratings",
    query=CREATE_ATHENA_TABLE_MOVIE_RATINGS_QUERY,
    database=GlobalArgs.ATHENA_DB,
    output_location=
    f"s3://{GlobalArgs.S3_BKT_NAME}/{GlobalArgs.ATHENA_RESULTS}/create_athena_table_movie_ratings"
)

# Task to move processed file
move_raw_files_to_processed_loc = S3CopyObjectOperator(
Example No. 12
# Set the path to the repartition.py notebook in the Databricks workspace
notebook_params = {
    'new_cluster': etl_cluster,
    'notebook_task': {
        'notebook_path': '/path_to_file_in_databricks/repartition'
    }
}

run_process_data = DatabricksSubmitRunOperator(task_id='process_data',
                                               json=notebook_params,
                                               retries=2,
                                               dag=dag)

### V1: a PythonOperator that executes the run_add_partitions function from athena.py
run_repair_partition = PythonOperator(
    task_id="repair_partition",
    dag=dag,
    python_callable=run_add_partitions,
    execution_timeout=timedelta(minutes=10),
    provide_context=True,
)

### V2: the same step with AWSAthenaOperator
run_repair_partition = AWSAthenaOperator(
    task_id='repair_partition',
    query='MSCK REPAIR TABLE amplitude_feed',
    output_location='s3://my-bucket/my-path/',
    database='my_database')

(run_process_data >> run_repair_partition)
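
The V1 callable isn't shown; a `run_add_partitions`-style function would presumably wrap the Athena API with boto3, along these lines (a sketch under assumed names, with values taken from the V2 operator rather than the author's athena.py):

import time
import boto3

def run_add_partitions(**context):
    """Sketch: run MSCK REPAIR TABLE via the Athena API and wait for a terminal state."""
    client = boto3.client('athena')
    response = client.start_query_execution(
        QueryString='MSCK REPAIR TABLE amplitude_feed',
        QueryExecutionContext={'Database': 'my_database'},
        ResultConfiguration={'OutputLocation': 's3://my-bucket/my-path/'},
    )
    query_id = response['QueryExecutionId']
    while True:
        status = client.get_query_execution(QueryExecutionId=query_id)
        state = status['QueryExecution']['Status']['State']
        if state not in ('QUEUED', 'RUNNING'):
            return state
        time.sleep(5)

V2 replaces all of this plumbing with a single operator, which is the point of the comparison.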
Example No. 13
    check_s3_for_key = S3KeySensor(
        task_id='check_s3_for_key',
        bucket_key=s3_key,
        wildcard_match=True,
        bucket_name=s3_bucket_name,
        s3_conn_id='aws_default',
        timeout=20,
        poke_interval=5,
        dag=dag
    )
    files_to_s3 = PythonOperator(
        task_id="files_to_s3",
        python_callable=download_zip
    )
    
    create_athena_movie_table = AWSAthenaOperator(
        task_id="create_athena_movie_table",
        query=create_athena_movie_table_query,
        database=athena_db,
        output_location='s3://' + s3_bucket_name + "/" + athena_results + 'create_athena_movie_table')

    create_athena_ratings_table = AWSAthenaOperator(
        task_id="create_athena_ratings_table",
        query=create_athena_ratings_table_query,
        database=athena_db,
        output_location='s3://' + s3_bucket_name + "/" + athena_results + 'create_athena_ratings_table')

    create_athena_tags_table = AWSAthenaOperator(
        task_id="create_athena_tags_table",
        query=create_athena_tags_table_query,
        database=athena_db,
        output_location='s3://' + s3_bucket_name + "/" + athena_results + 'create_athena_tags_table')

    join_athena_tables = AWSAthenaOperator(
        task_id="join_athena_tables",
        query=join_tables_athena_query,
        database=athena_db,
        output_location='s3://' + s3_bucket_name + "/" + athena_results + 'join_athena_tables')
    
    create_redshift_table_if_not_exists = PythonOperator(
        task_id="create_redshift_table_if_not_exists",
        python_callable=create_redshift_table
    )

    clean_up_csv = PythonOperator(
        task_id="clean_up_csv",
        python_callable=clean_up_csv_fn,
    )
Example No. 14
with DAG(
        dag_id="athena_to_pg_dag_v1",
        description="Commercial DAG",
        start_date=dt.datetime(2019, 7, 29),
        schedule_interval="55 11 * * 1-5",
        default_args=DEFAULT_ARGS,
        catchup=False,
) as dag:

    start_task = DummyOperator(task_id="start_task")

    get_routes_data_task = AWSAthenaOperator(
        task_id="get_routes_data_task",
        aws_conn_id="aws_default",
        query=GET_ROUTES_QUERY,
        database="db_logistics",
        output_location=
        f"s3://gln-airflow/commercial/athena-routes-data/{dt.datetime.now():%Y-%m-%d}",
    )

    load_routes_task = PythonOperator(
        task_id="load_routes_task",
        python_callable=load_athena_to_postgres,
        op_kwargs={
            "p_filename": ROUTE_FILENAME,
            "p_buckpref":
            f"commercial/athena-routes-data/{dt.datetime.now():%Y-%m-%d}",
            "p_staging_table": "sales.transportation_zones_staging",
            "p_target_table": "sales.transportation_zones",
            "p_target_sql": PG_LOAD_ROUTES_SQL,
        },
Example No. 15
    # Add environment-specific variables
    **conf['dev'],
}

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime.now() - timedelta(minutes=10),
    'email': ['*****@*****.**'],
    'email_on_failure': False,
    'email_on_retry': False,
    "annotations": {
        "iam.amazonaws.com/role": env_vars["AWS_IAM_ARN"]
    }
}

dag = DAG('test_athena',
          default_args=default_args,
          schedule_interval=timedelta(minutes=10),
          catchup=False)

run_query = AWSAthenaOperator(task_id='run_query',
                              query='My Awesome query',
                              aws_conn_id='aws_connection',
                              output_location='S3 path output location',
                              database='myDatabase',
                              workgroup='myworkgroup',
                              retries=4,
                              retry_delay=timedelta(seconds=10),
                              dag=dag)
Example No. 16
class TestAWSAthenaOperator(unittest.TestCase):
    def setUp(self):
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'provide_context': True
        }

        self.dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once',
                       default_args=args,
                       schedule_interval='@once')
        self.athena = AWSAthenaOperator(
            task_id='test_aws_athena_operator',
            query='SELECT * FROM TEST_TABLE',
            database='TEST_DATABASE',
            output_location='s3://test_s3_bucket/',
            client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
            sleep_time=1,
            max_tries=3,
            dag=self.dag)

    def test_init(self):
        self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
        self.assertEqual(self.athena.query, MOCK_DATA['query'])
        self.assertEqual(self.athena.database, MOCK_DATA['database'])
        self.assertEqual(self.athena.aws_conn_id, 'aws_default')
        self.assertEqual(self.athena.client_request_token,
                         MOCK_DATA['client_request_token'])
        self.assertEqual(self.athena.sleep_time, 1)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=("SUCCESS", ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_small_success_query(self, mock_conn, mock_run_query,
                                          mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 1)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "RUNNING",
                           "SUCCESS",
                       ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_big_success_query(self, mock_conn, mock_run_query,
                                        mock_check_query_status):
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           None,
                           None,
                       ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query,
                                             mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'get_state_change_reason')
    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "FAILED",
                       ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failure_query(self, mock_conn, mock_run_query,
                                    mock_check_query_status,
                                    mock_get_state_change_reason):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 2)
        self.assertEqual(mock_get_state_change_reason.call_count, 1)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "RUNNING",
                           "CANCELLED",
                       ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_cancelled_query(self, mock_conn, mock_run_query,
                                      mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=(
                           "RUNNING",
                           "RUNNING",
                           "RUNNING",
                       ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failed_query_with_max_tries(self, mock_conn,
                                                  mock_run_query,
                                                  mock_check_query_status):
        with self.assertRaises(Exception):
            self.athena.execute(None)
        mock_run_query.assert_called_once_with(
            MOCK_DATA['query'], query_context, result_configuration,
            MOCK_DATA['client_request_token'], MOCK_DATA['workgroup'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook,
                       'check_query_status',
                       side_effect=("SUCCESS", ))
    @mock.patch.object(AWSAthenaHook,
                       'run_query',
                       return_value=ATHENA_QUERY_ID)
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_xcom_push_and_pull(self, mock_conn, mock_run_query,
                                mock_check_query_status):
        ti = TaskInstance(task=self.athena, execution_date=timezone.utcnow())
        ti.run()

        self.assertEqual(ti.xcom_pull(task_ids='test_aws_athena_operator'),
                         ATHENA_QUERY_ID)
Example No. 17
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2020, 9, 7),
    "email": ["*****@*****.**"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": timedelta(minutes=5)
}

with DAG("query_s3", default_args=default_args, schedule_interval= '@once') as dag:

    t1 = BashOperator(
        task_id='bash_test',
        bash_command='echo "Starting AWSAthenaOperator TEST"'
    )

    run_query = AWSAthenaOperator(
        task_id='run_query',
        database='mr-csv',
        query='select text FROM "{DATABASE}"."{TABLE}"',
        output_location='s3://XXX/YYY/ZZZ',
        aws_conn_id='s3_connection'
    )

    
    run_query.set_upstream(t1)