def test_bigquery_operator_extra_link_when_single_query(self, mock_hook, session):
    bigquery_task = BigQueryExecuteQueryOperator(
        task_id=TASK_ID,
        sql='SELECT * FROM test_table',
        dag=self.dag,
    )
    self.dag.clear()
    session.query(XCom).delete()

    ti = TaskInstance(
        task=bigquery_task,
        execution_date=DEFAULT_DATE,
    )

    job_id = '12345'
    ti.xcom_push(key='job_id', value=job_id)

    self.assertEqual(
        'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=job_id),
        bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name),
    )

    self.assertEqual(
        '',
        bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name),
    )

def test_bigquery_operator_extra_link(self, mock_hook):
    bigquery_task = BigQueryOperator(
        task_id=TASK_ID,
        sql='SELECT * FROM test_table',
        dag=self.dag,
    )
    self.dag.clear()

    ti = TaskInstance(
        task=bigquery_task,
        execution_date=DEFAULT_DATE,
    )

    job_id = '12345'
    ti.xcom_push(key='job_id', value=job_id)

    self.assertEqual(
        'https://console.cloud.google.com/bigquery?j={job_id}'.format(job_id=job_id),
        bigquery_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name),
    )

    self.assertEqual(
        '',
        bigquery_task.get_extra_links(datetime(2019, 1, 1), BigQueryConsoleLink.name),
    )

def test_set_checkpoint_no_current_checkpoint_prefetch_has_data_true(env, bigquery_helper, seed):
    table = 'lake.tree_users'
    seeds = [('system', [('checkpoint', [])])]
    seed(seeds)

    task_id = 'set_checkpoint_no_current_record'
    with DAG(dag_id='set_checkpoint_test', start_date=datetime.now()) as dag:
        task = SetCheckpointOperator(env=env['env'], table=table, dag=dag, task_id=task_id)

    ti = TaskInstance(task=task, execution_date=datetime.now())
    ti.xcom_push(key=table, value={
        'first_ingestion_timestamp': '1970-01-01 00:00:00+00:00',
        'last_ingestion_timestamp': '2020-03-27 06:05:00+00:00',
        'has_data': True
    })
    task.execute(ti.get_template_context())

    rs = bigquery_helper.query(
        f"SELECT * FROM {env['project']}.system.checkpoint WHERE table = '{table}'"
    )
    assert str(rs[0]['checkpoint']) == '2020-03-27 06:05:00+00:00'

def test_bigquery_operator_extra_link_when_multiple_query(self, mock_hook, session):
    bigquery_task = BigQueryExecuteQueryOperator(
        task_id=TASK_ID,
        sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'],
        dag=self.dag,
    )
    self.dag.clear()
    session.query(XCom).delete()

    ti = TaskInstance(
        task=bigquery_task,
        execution_date=DEFAULT_DATE,
    )

    job_id = ['123', '45']
    ti.xcom_push(key='job_id', value=job_id)

    self.assertEqual(
        {'BigQuery Console #1', 'BigQuery Console #2'},
        bigquery_task.operator_extra_link_dict.keys(),
    )

    self.assertEqual(
        'https://console.cloud.google.com/bigquery?j=123',
        bigquery_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #1'),
    )

    self.assertEqual(
        'https://console.cloud.google.com/bigquery?j=45',
        bigquery_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #2'),
    )

def test_xcom_pull_after_success(self):
    """
    tests xcom set/clear relative to a task in a 'success' rerun scenario
    """
    key = 'xcom_key'
    value = 'xcom_value'
    dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
    task = DummyOperator(
        task_id='test_xcom',
        dag=dag,
        pool='test_xcom',
        owner='airflow',
        start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
    exec_date = datetime.datetime.now()
    ti = TI(task=task, execution_date=exec_date)

    ti.run(mark_success=True)
    ti.xcom_push(key=key, value=value)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
    ti.run()
    # The second run and assert is to handle AIRFLOW-131 (don't clear on
    # prior success)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

    # Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't
    # execute, even if dependencies are ignored
    ti.run(ignore_all_deps=True, mark_success=True)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

    # Xcom IS finally cleared once task has executed
    ti.run(ignore_all_deps=True)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)

def test_xcom_pull_different_execution_date(self):
    """
    tests xcom fetch behavior with different execution dates, using
    both xcom_pull with "include_prior_dates" and without
    """
    key = 'xcom_key'
    value = 'xcom_value'
    dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
    task = DummyOperator(
        task_id='test_xcom',
        dag=dag,
        pool='test_xcom',
        owner='airflow',
        start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
    exec_date = datetime.datetime.now()
    ti = TI(task=task, execution_date=exec_date)

    ti.run(mark_success=True)
    ti.xcom_push(key=key, value=value)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
    ti.run()
    exec_date += datetime.timedelta(days=1)
    ti = TI(task=task, execution_date=exec_date)
    ti.run()
    # We have set a new execution date (and did not pass in
    # 'include_prior_dates'), which means this task should now have a
    # cleared xcom value
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), None)
    # We *should* get a value using 'include_prior_dates'
    self.assertEqual(
        ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True),
        value)

def test_console_extra_link_serialized_field(self):
    with self.dag:
        training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)

    serialized_dag = SerializedDAG.to_dict(self.dag)
    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']]

    # Check Serialized version of operator link
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}],
    )

    # Check DeSerialized version of operator link
    self.assertIsInstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)

    job_id = self.TRAINING_DEFAULT_ARGS['job_id']
    project_id = self.TRAINING_DEFAULT_ARGS['project_id']
    gcp_metadata = {
        "job_id": job_id,
        "project_id": project_id,
    }

    ti = TaskInstance(task=training_op, execution_date=DEFAULT_DATE)
    ti.xcom_push(key='gcp_metadata', value=gcp_metadata)

    self.assertEqual(
        f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
        simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
    )

    self.assertEqual(
        '',
        simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
    )

def test_console_extra_link(self, mock_hook):
    training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)

    ti = TaskInstance(
        task=training_op,
        execution_date=DEFAULT_DATE,
    )

    job_id = self.TRAINING_DEFAULT_ARGS['job_id']
    project_id = self.TRAINING_DEFAULT_ARGS['project_id']
    gcp_metadata = {
        "job_id": job_id,
        "project_id": project_id,
    }
    ti.xcom_push(key='gcp_metadata', value=gcp_metadata)

    self.assertEqual(
        f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
        training_op.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
    )

    self.assertEqual(
        '',
        training_op.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
    )

def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self):
    with self.dag:
        BigQueryExecuteQueryOperator(
            task_id=TASK_ID,
            sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'],
        )
    serialized_dag = SerializedDAG.to_dict(self.dag)
    self.assertIn("sql", serialized_dag["dag"]["tasks"][0])

    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict[TASK_ID]
    self.assertEqual(
        getattr(simple_task, "sql"),
        ['SELECT * FROM test_table', 'SELECT * FROM test_table2'])

    #########################################################
    # Verify Operator Links work with Serialized Operator
    #########################################################
    # Check Serialized version of operator link
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [
            {
                'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
                    'index': 0
                }
            },
            {
                'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink': {
                    'index': 1
                }
            },
        ],
    )

    # Check DeSerialized version of operator link
    self.assertIsInstance(
        list(simple_task.operator_extra_links)[0], BigQueryConsoleIndexableLink)

    ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
    job_id = ['123', '45']
    ti.xcom_push(key='job_id', value=job_id)

    self.assertEqual(
        {'BigQuery Console #1', 'BigQuery Console #2'},
        simple_task.operator_extra_link_dict.keys())

    self.assertEqual(
        'https://console.cloud.google.com/bigquery?j=123',
        simple_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #1'),
    )

    self.assertEqual(
        'https://console.cloud.google.com/bigquery?j=45',
        simple_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #2'),
    )

def test_bigquery_operator_extra_serialized_field_when_single_query(self):
    with self.dag:
        BigQueryExecuteQueryOperator(
            task_id=TASK_ID,
            sql='SELECT * FROM test_table',
        )
    serialized_dag = SerializedDAG.to_dict(self.dag)
    assert "sql" in serialized_dag["dag"]["tasks"][0]

    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict[TASK_ID]
    assert getattr(simple_task, "sql") == 'SELECT * FROM test_table'

    #########################################################
    # Verify Operator Links work with Serialized Operator
    #########################################################
    # Check Serialized version of operator link
    assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
        {'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}
    ]

    # Check DeSerialized version of operator link
    assert isinstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)

    ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
    ti.xcom_push('job_id', 12345)

    # check for positive case
    url = simple_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
    assert url == 'https://console.cloud.google.com/bigquery?j=12345'

    # check for negative case
    url2 = simple_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name)
    assert url2 == ''

def _return_via_ti(ti: TaskInstance) -> None:
    """Push a randomly generated value to XCom under the key ``_KEY``.

    Args:
        ti: the task instance
    """
    value = randint(1, 1_000)
    ti.xcom_push(key=_KEY, value=value)
    print(f"Value pushed to xcom: {value}")

def _generate_value(ti: TaskInstance) -> None:
    """Generate a value between 1 and 1_000, inclusive, and push it to XCom.

    Args:
        ti: the task instance
    """
    value = randint(1, 1_000)
    ti.xcom_push(key="generated_value", value=value)
    print(f"Value pushed to xcom: {value}")

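# Hedged sketch (not from the source): one way the two helpers above might be wired into a
# DAG. In Airflow 2.x, PythonOperator injects context variables such as `ti` into matching
# parameters of the callable, so no extra plumbing is required. The dag_id, task_id and
# schedule below are illustrative assumptions only.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(dag_id="xcom_push_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as example_dag:
    generate_value = PythonOperator(task_id="generate_value", python_callable=_generate_value)
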
def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
    if sys.getsizeof(data) < MAX_XCOM_SIZE:
        task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
    else:
        raise RuntimeError('The size of the downloaded data is too large to push to XCom!')

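# Hedged sketch (not from the source): how a downstream callable might read back the response
# pushed by the helper above. When `google_api_response_via_xcom` is unset, the helper pushes
# under Airflow's default key XCOM_RETURN_KEY ("return_value"). The upstream task_id here is
# an illustrative assumption.
def _consume_api_response(ti) -> None:
    data = ti.xcom_pull(task_ids="google_api_to_s3", key="return_value")
    print(f"Records received via XCom: {len(data)}")
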
def test_extra_serialized_field_and_multiple_operator_links(self):
    """
    Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.

    This test also depends on GoogleLink() registered as a plugin
    in tests/plugins/test_plugin.py

    The function tests that if extra operator links are registered in plugin
    in ``operator_extra_links`` and the same is also defined in
    the Operator in ``BaseOperator.operator_extra_links``, it has the correct
    extra link.
    """
    test_date = datetime(2019, 8, 1)
    dag = DAG(dag_id='simple_dag', start_date=test_date)
    CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"])

    serialized_dag = SerializedDAG.to_dict(dag)
    assert "bash_command" in serialized_dag["dag"]["tasks"][0]

    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict["simple_task"]
    assert getattr(simple_task, "bash_command") == ["echo", "true"]

    #########################################################
    # Verify Operator Links work with Serialized Operator
    #########################################################
    # Check Serialized version of operator link only contains the inbuilt Op Link
    assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
        {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 0}},
        {'tests.test_utils.mock_operators.CustomBaseIndexOpLink': {'index': 1}},
    ]

    # Test all the extra_links are set
    assert set(simple_task.extra_links) == {
        'BigQuery Console #1',
        'BigQuery Console #2',
        'airflow',
        'github',
        'google',
    }

    ti = TaskInstance(task=simple_task, execution_date=test_date)
    ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"])

    # Test Deserialized inbuilt link #1
    custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #1")
    assert 'https://console.cloud.google.com/bigquery?j=dummy_value_1' == custom_inbuilt_link

    # Test Deserialized inbuilt link #2
    custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #2")
    assert 'https://console.cloud.google.com/bigquery?j=dummy_value_2' == custom_inbuilt_link

    # Test Deserialized link registered via Airflow Plugin
    google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name)
    assert "https://www.google.com" == google_link_from_plugin

def test_extra_serialized_field_and_operator_links(self):
    """
    Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links.

    This test also depends on GoogleLink() registered as a plugin
    in tests/plugins/test_plugin.py

    The function tests that if extra operator links are registered in plugin
    in ``operator_extra_links`` and the same is also defined in
    the Operator in ``BaseOperator.operator_extra_links``, it has the correct
    extra link.
    """
    test_date = datetime(2019, 8, 1)
    dag = DAG(dag_id='simple_dag', start_date=test_date)
    CustomOperator(task_id='simple_task', dag=dag, bash_command="true")

    serialized_dag = SerializedDAG.to_dict(dag)
    self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0])

    dag = SerializedDAG.from_dict(serialized_dag)
    simple_task = dag.task_dict["simple_task"]
    self.assertEqual(getattr(simple_task, "bash_command"), "true")

    #########################################################
    # Verify Operator Links work with Serialized Operator
    #########################################################
    # Check Serialized version of operator link only contains the inbuilt Op Link
    self.assertEqual(
        serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
        [{'tests.test_utils.mock_operators.CustomOpLink': {}}],
    )

    # Test all the extra_links are set
    self.assertCountEqual(simple_task.extra_links,
                          ['Google Custom', 'airflow', 'github', 'google'])

    ti = TaskInstance(task=simple_task, execution_date=test_date)
    ti.xcom_push('search_query', "dummy_value_1")

    # Test Deserialized inbuilt link
    custom_inbuilt_link = simple_task.get_extra_links(test_date, CustomOpLink.name)
    self.assertEqual('http://google.com/custom_base_link?search=dummy_value_1', custom_inbuilt_link)

    # Test Deserialized link registered via Airflow Plugin
    google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name)
    self.assertEqual("https://www.google.com", google_link_from_plugin)

def test_render_template_2(self):
    dag = DAG(dag_id='test_xcom', default_args=self.args)

    xcom_steps = [
        {
            'Name': 'test_step1',
            'ActionOnFailure': 'CONTINUE',
            'HadoopJarStep': {
                'Jar': 'command-runner.jar',
                'Args': ['/usr/lib/spark/bin/run-example1']
            }
        },
        {
            'Name': 'test_step2',
            'ActionOnFailure': 'CONTINUE',
            'HadoopJarStep': {
                'Jar': 'command-runner.jar',
                'Args': ['/usr/lib/spark/bin/run-example2']
            }
        }
    ]

    make_steps = DummyOperator(task_id='make_steps', dag=dag, owner='airflow')
    execution_date = timezone.utcnow()
    ti1 = TaskInstance(task=make_steps, execution_date=execution_date)
    ti1.xcom_push(key='steps', value=xcom_steps)

    self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN

    test_task = EmrAddStepsOperator(
        task_id='test_task',
        job_flow_id='j-8989898989',
        aws_conn_id='aws_default',
        steps="{{ ti.xcom_pull(task_ids='make_steps',key='steps') }}",
        dag=dag)

    with patch('boto3.session.Session', self.boto3_session_mock):
        ti = TaskInstance(task=test_task, execution_date=execution_date)
        ti.run()

    self.emr_client_mock.add_job_flow_steps.assert_called_once_with(
        JobFlowId='j-8989898989',
        Steps=xcom_steps)

def create_context(task):
    dag = DAG(dag_id="dag")
    tzinfo = pendulum.timezone("Europe/Amsterdam")
    execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
    task_instance = TaskInstance(task=task, execution_date=execution_date)
    task_instance.xcom_push = mock.Mock()
    return {
        "dag": dag,
        "ts": execution_date.isoformat(),
        "task": task,
        "ti": task_instance,
        "task_instance": task_instance,
    }

def test_xcom_pull_after_success(self):
    """
    tests xcom set/clear relative to a task in a 'success' rerun scenario
    """
    key = 'xcom_key'
    value = 'xcom_value'
    dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
    task = DummyOperator(
        task_id='test_xcom',
        dag=dag,
        pool='test_xcom',
        owner='airflow',
        start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
    exec_date = datetime.datetime.now()
    ti = TI(task=task, execution_date=exec_date)
    ti.run(mark_success=True)
    ti.xcom_push(key=key, value=value)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
    ti.run()
    # The second run and assert is to handle AIRFLOW-131 (don't clear on
    # prior success)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

def test_xcom_pull(self):
    """
    Test xcom_pull, using different filtering methods.
    """
    dag = models.DAG(
        dag_id='test_xcom',
        schedule_interval='@monthly',
        start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))

    exec_date = timezone.utcnow()

    # Push a value
    task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow')
    ti1 = TI(task=task1, execution_date=exec_date)
    ti1.xcom_push(key='foo', value='bar')

    # Push another value with the same key (but by a different task)
    task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow')
    ti2 = TI(task=task2, execution_date=exec_date)
    ti2.xcom_push(key='foo', value='baz')

    # Pull with no arguments
    result = ti1.xcom_pull()
    self.assertEqual(result, None)
    # Pull the value pushed most recently by any task.
    result = ti1.xcom_pull(key='foo')
    self.assertIn(result, 'baz')
    # Pull the value pushed by the first task
    result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
    self.assertEqual(result, 'bar')
    # Pull the value pushed by the second task
    result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
    self.assertEqual(result, 'baz')
    # Pull the values pushed by both tasks
    result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
    self.assertEqual(result, ('bar', 'baz'))

def test_xcom_pull_after_success(self):
    """
    tests xcom set/clear relative to a task in a 'success' rerun scenario
    """
    key = 'xcom_key'
    value = 'xcom_value'
    dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
    task = DummyOperator(
        task_id='test_xcom',
        dag=dag,
        pool='test_xcom',
        owner='airflow',
        start_date=datetime.datetime(2016, 6, 2, 0, 0, 0))
    exec_date = datetime.datetime.now()
    ti = TI(task=task, execution_date=exec_date)
    ti.run(mark_success=True)
    ti.xcom_push(key=key, value=value)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)
    ti.run()
    # The second run and assert is to handle AIRFLOW-131 (don't clear on
    # prior success)
    self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value)

def test_xcom_pull(self):
    """
    Test xcom_pull, using different filtering methods.
    """
    dag = models.DAG(
        dag_id='test_xcom',
        schedule_interval='@monthly',
        start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))

    exec_date = timezone.utcnow()

    # Push a value
    task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow')
    ti1 = TI(task=task1, execution_date=exec_date)
    ti1.xcom_push(key='foo', value='bar')

    # Push another value with the same key (but by a different task)
    task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow')
    ti2 = TI(task=task2, execution_date=exec_date)
    ti2.xcom_push(key='foo', value='baz')

    # Pull with no arguments
    result = ti1.xcom_pull()
    self.assertEqual(result, None)
    # Pull the value pushed most recently by any task.
    result = ti1.xcom_pull(key='foo')
    self.assertIn(result, 'baz')
    # Pull the value pushed by the first task
    result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
    self.assertEqual(result, 'bar')
    # Pull the value pushed by the second task
    result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
    self.assertEqual(result, 'baz')
    # Pull the values pushed by both tasks
    result = ti1.xcom_pull(
        task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
    self.assertEqual(result, ('bar', 'baz'))

def create_context(task):
    dag = DAG(dag_id="dag")
    tzinfo = pendulum.timezone("Europe/Amsterdam")
    execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
    dag_run = DagRun(
        dag_id=dag.dag_id,
        execution_date=execution_date,
        run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
    )
    task_instance = TaskInstance(task=task)
    task_instance.dag_run = dag_run
    task_instance.dag_id = dag.dag_id
    task_instance.xcom_push = mock.Mock()
    return {
        "dag": dag,
        "run_id": dag_run.run_id,
        "task": task,
        "ti": task_instance,
        "task_instance": task_instance,
    }

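# Hedged usage sketch (not from the source): how a unit test might consume create_context().
# The operator class and the asserted key are hypothetical placeholders; the point is that
# create_context() replaces xcom_push with a Mock, so the push can be inspected after execute().
def test_operator_pushes_to_xcom():
    task = MyOperator(task_id="push_result")  # hypothetical operator under test
    context = create_context(task)
    task.execute(context)
    context["ti"].xcom_push.assert_called_once_with(key="result", value=mock.ANY)
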