class TestDataflowPythonOperator(unittest.TestCase):
    """Tests for DataflowCreatePythonJobOperator against the legacy DataflowHook path.

    Verifies operator initialization and that ``execute`` stages the py_file via
    GCSHook and delegates to ``DataflowHook.start_python_dataflow`` with the
    merged default + additional options.
    """

    def setUp(self):
        self.dataflow = DataflowCreatePythonJobOperator(
            task_id=TASK_ID,
            py_file=PY_FILE,
            job_name=JOB_NAME,
            py_options=PY_OPTIONS,
            dataflow_default_options=DEFAULT_OPTIONS_PYTHON,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )

    def test_init(self):
        """Test DataFlowPythonOperator instance is properly initialized."""
        assert self.dataflow.task_id == TASK_ID
        assert self.dataflow.job_name == JOB_NAME
        assert self.dataflow.py_file == PY_FILE
        assert self.dataflow.py_options == PY_OPTIONS
        assert self.dataflow.py_interpreter == PY_INTERPRETER
        assert self.dataflow.poll_sleep == POLL_SLEEP
        assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_PYTHON
        # Options passed to the operator are normalized (e.g. camelCase keys),
        # so compare against the expected normalized form.
        assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_mock):
        """Test DataflowHook is created and the right args are passed to
        start_python_workflow.
        """
        start_python_hook = dataflow_mock.return_value.start_python_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.execute(None)
        assert dataflow_mock.called
        # dataflow_default_options and options are merged into a single
        # variables mapping; the airflow-version label is injected.
        expected_options = {
            'project': 'test',
            'staging_location': 'gs://test/staging',
            'output': 'gs://test/output',
            'labels': {'foo': 'bar', 'airflow-version': TEST_VERSION},
        }
        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
        start_python_hook.assert_called_once_with(
            job_name=JOB_NAME,
            variables=expected_options,
            dataflow=mock.ANY,
            py_options=PY_OPTIONS,
            py_interpreter=PY_INTERPRETER,
            py_requirements=None,
            py_system_site_packages=False,
            on_new_job_id_callback=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )
        # The gs:// source file is downloaded to a local temp path before
        # being handed to the Dataflow runner.
        assert self.dataflow.py_file.startswith('/tmp/dataflow')
class TestDataflowPythonOperator(unittest.TestCase):
    """Tests for DataflowCreatePythonJobOperator against the BeamHook execution path.

    Verifies operator initialization and that ``execute`` stages the py_file via
    GCSHook, launches the pipeline through ``BeamHook.start_python_pipeline``
    with the DataflowRunner, and waits for completion via
    ``DataflowHook.wait_for_done``.
    """

    def setUp(self):
        self.dataflow = DataflowCreatePythonJobOperator(
            task_id=TASK_ID,
            py_file=PY_FILE,
            job_name=JOB_NAME,
            py_options=PY_OPTIONS,
            dataflow_default_options=DEFAULT_OPTIONS_PYTHON,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )
        # The airflow-version label must be sanitized the same way the hook
        # does it: dots and pluses are not valid label characters.
        self.expected_airflow_version = 'v' + airflow.version.version.replace(".", "-").replace("+", "-")

    def test_init(self):
        """Test DataFlowPythonOperator instance is properly initialized."""
        assert self.dataflow.task_id == TASK_ID
        assert self.dataflow.job_name == JOB_NAME
        assert self.dataflow.py_file == PY_FILE
        assert self.dataflow.py_options == PY_OPTIONS
        assert self.dataflow.py_interpreter == PY_INTERPRETER
        assert self.dataflow.poll_sleep == POLL_SLEEP
        assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_PYTHON
        # Options passed to the operator are normalized (e.g. camelCase keys),
        # so compare against the expected normalized form.
        assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id):
        """Test DataflowHook is created and the right args are passed to
        start_python_workflow.
        """
        start_python_mock = beam_hook_mock.return_value.start_python_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value

        self.dataflow.execute(None)

        beam_hook_mock.assert_called_once_with(runner="DataflowRunner")
        # The gs:// source file is downloaded to a local temp path before
        # being handed to the Beam runner.
        assert self.dataflow.py_file.startswith('/tmp/dataflow')
        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
        dataflow_hook_mock.assert_called_once_with(
            gcp_conn_id="google_cloud_default",
            delegate_to=mock.ANY,
            poll_sleep=POLL_SLEEP,
            impersonation_chain=None,
            drain_pipeline=False,
            cancel_timeout=mock.ANY,
            wait_until_finished=None,
        )
        # Merged default + additional options, with project/job_name/region
        # resolved by the hooks and the sanitized airflow-version label added.
        expected_options = {
            "project": dataflow_hook_mock.return_value.project_id,
            "staging_location": 'gs://test/staging',
            "job_name": job_name,
            "region": TEST_LOCATION,
            'output': 'gs://test/output',
            'labels': {'foo': 'bar', 'airflow-version': self.expected_airflow_version},
        }
        start_python_mock.assert_called_once_with(
            variables=expected_options,
            py_file=gcs_provide_file.return_value.__enter__.return_value.name,
            py_options=PY_OPTIONS,
            py_interpreter=PY_INTERPRETER,
            py_requirements=None,
            py_system_site_packages=False,
            process_line_callback=mock_callback_on_job_id.return_value,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=mock.ANY,
            job_name=job_name,
            location=TEST_LOCATION,
            multiple_jobs=False,
        )