class TestDataflowJavaOperator(unittest.TestCase):
    """Tests for ``DataflowCreateJavaJobOperator`` using ``DataflowHook.start_java_dataflow``.

    NOTE(review): a second class named ``TestDataflowJavaOperator`` is defined
    later in this file. That later definition rebinds the module-level name, so
    test runners only collect it and the tests in this class never execute.
    Confirm whether this copy should be removed or renamed (renaming would make
    these tests run against whatever operator API is current).
    """

    def setUp(self):
        self.dataflow = DataflowCreateJavaJobOperator(
            task_id=TASK_ID,
            jar=JAR_FILE,
            job_name=JOB_NAME,
            job_class=JOB_CLASS,
            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )

    def test_init(self):
        """Test DataflowCreateJavaJobOperator instance is properly initialized."""
        self.assertEqual(self.dataflow.task_id, TASK_ID)
        self.assertEqual(self.dataflow.job_name, JOB_NAME)
        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
        self.assertEqual(self.dataflow.dataflow_default_options, DEFAULT_OPTIONS_JAVA)
        self.assertEqual(self.dataflow.job_class, JOB_CLASS)
        self.assertEqual(self.dataflow.jar, JAR_FILE)
        self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS)
        self.assertEqual(self.dataflow.check_if_running, CheckJobRunning.WaitForRun)

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_mock):
        """Test DataflowHook is created and the right args are passed to start_java_dataflow."""
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = CheckJobRunning.IgnoreJob
        self.dataflow.execute(None)
        self.assertTrue(dataflow_mock.called)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        start_java_hook.assert_called_once_with(
            job_name=JOB_NAME,
            variables=mock.ANY,
            jar=mock.ANY,
            job_class=JOB_CLASS,
            append_job_name=True,
            multiple_jobs=None,
            on_new_job_id_callback=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_running_exec(self, gcs_hook, dataflow_mock):
        """Test the jar is not downloaded and start_java_dataflow is not called
        when the running check reports the job as already running.
        """
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = True
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = True
        self.dataflow.execute(None)
        self.assertTrue(dataflow_mock.called)
        gcs_provide_file.assert_not_called()
        start_java_hook.assert_not_called()
        dataflow_running.assert_called_once_with(
            name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
        )

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock):
        """Test the running check is performed and, since no job is running,
        the right args are passed to start_java_dataflow.
        """
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = False
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = True
        self.dataflow.execute(None)
        self.assertTrue(dataflow_mock.called)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        start_java_hook.assert_called_once_with(
            job_name=JOB_NAME,
            variables=mock.ANY,
            jar=mock.ANY,
            job_class=JOB_CLASS,
            append_job_name=True,
            multiple_jobs=None,
            on_new_job_id_callback=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )
        dataflow_running.assert_called_once_with(
            name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
        )

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock):
        """Test multiple_jobs=True is forwarded to start_java_dataflow after the
        running check finds no active job.
        """
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = False
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.multiple_jobs = True
        self.dataflow.check_if_running = True
        self.dataflow.execute(None)
        self.assertTrue(dataflow_mock.called)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        start_java_hook.assert_called_once_with(
            job_name=JOB_NAME,
            variables=mock.ANY,
            jar=mock.ANY,
            job_class=JOB_CLASS,
            append_job_name=True,
            multiple_jobs=True,
            on_new_job_id_callback=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )
        dataflow_running.assert_called_once_with(
            name=JOB_NAME, variables=mock.ANY, project_id=None, location=TEST_LOCATION
        )
class TestDataflowJavaOperator(unittest.TestCase):
    """Tests for ``DataflowCreateJavaJobOperator`` launching pipelines through ``BeamHook``."""

    def setUp(self):
        self.dataflow = DataflowCreateJavaJobOperator(
            task_id=TASK_ID,
            jar=JAR_FILE,
            job_name=JOB_NAME,
            job_class=JOB_CLASS,
            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )
        # Expected value of the 'airflow-version' label: 'v' + the installed
        # version with '.' and '+' replaced by '-'.
        self.expected_airflow_version = 'v' + airflow.version.version.replace(
            ".", "-").replace("+", "-")

    def _expected_variables(self, project, job_name):
        """Return a fresh dict of the pipeline options the operator should build.

        A new dict is returned on every call so tests may mutate it freely.

        :param project: value expected under ``'project'`` (the mocked
            DataflowHook's ``project_id`` in these tests).
        :param job_name: value expected under ``'jobName'``.
        """
        return {
            'project': project,
            'stagingLocation': 'gs://test/staging',
            'jobName': job_name,
            'region': TEST_LOCATION,
            'output': 'gs://test/output',
            'labels': {
                'foo': 'bar',
                'airflow-version': self.expected_airflow_version,
            },
        }

    @staticmethod
    def _capture_is_running_variables(is_running_mock):
        """Snapshot the ``variables`` kwarg each time ``is_running_mock`` is called.

        Returns a dict whose ``'variables'`` entry holds a deep copy of the
        options from the most recent call (``None`` until the mock is called).
        A copy is needed because the options checked for a running job differ
        from those later passed to ``start_java_pipeline`` (``jobName`` changes).

        The side effect returns ``mock.DEFAULT`` so the mock keeps returning its
        configured ``return_value``; returning ``None`` from a side effect would
        make the mock return ``None`` instead.
        """
        captured = {'variables': None}

        def _record(*args, **kwargs):
            captured['variables'] = copy.deepcopy(kwargs.get("variables"))
            return mock.DEFAULT

        is_running_mock.side_effect = _record
        return captured

    def test_init(self):
        """Test DataflowCreateJavaJobOperator instance is properly initialized."""
        assert self.dataflow.task_id == TASK_ID
        assert self.dataflow.job_name == JOB_NAME
        assert self.dataflow.poll_sleep == POLL_SLEEP
        assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_JAVA
        assert self.dataflow.job_class == JOB_CLASS
        assert self.dataflow.jar == JAR_FILE
        assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS
        assert self.dataflow.check_if_running == CheckJobRunning.WaitForRun

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, mock_callback_on_job_id):
        """Test the jar is fetched from GCS and the right args are passed to
        BeamHook.start_java_pipeline and DataflowHook.wait_for_done.
        """
        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
        self.dataflow.check_if_running = CheckJobRunning.IgnoreJob

        self.dataflow.execute(None)

        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        start_java_mock.assert_called_once_with(
            variables=self._expected_variables(dataflow_hook_mock.return_value.project_id, job_name),
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=mock_callback_on_job_id.return_value,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=mock.ANY,
            job_name=job_name,
            location=TEST_LOCATION,
            multiple_jobs=None,
        )

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_running_exec(self, gcs_hook, dataflow_mock, beam_hook_mock):
        """Test start_java_pipeline is not called when a job with the same
        name is reported as already running.
        """
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = True
        start_java_hook = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = CheckJobRunning.FinishIfRunning

        self.dataflow.execute(None)

        assert dataflow_mock.called
        start_java_hook.assert_not_called()
        # The jar is still expected to be fetched even though the pipeline is skipped.
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        dataflow_running.assert_called_once_with(
            name=JOB_NAME,
            variables=self._expected_variables(dataflow_mock.return_value.project_id, JOB_NAME),
        )

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_not_running_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock,
                                        mock_callback_on_job_id):
        """Test the running check is performed and, since no job is running,
        the pipeline is started with the right args.
        """
        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
        captured = self._capture_is_running_variables(dataflow_running)
        dataflow_running.return_value = False
        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = CheckJobRunning.FinishIfRunning

        self.dataflow.execute(None)

        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        project = dataflow_hook_mock.return_value.project_id
        # The running check is expected to see the user-supplied job name...
        assert captured['variables'] == self._expected_variables(project, JOB_NAME)
        # ...while the started pipeline gets the name built by the hook.
        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
        start_java_mock.assert_called_once_with(
            variables=self._expected_variables(project, job_name),
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=mock_callback_on_job_id.return_value,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=mock.ANY,
            job_name=job_name,
            location=TEST_LOCATION,
            multiple_jobs=None,
        )

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_multiple_job_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock,
                                     mock_callback_on_job_id):
        """Test multiple_jobs=True is forwarded to wait_for_done after the
        running check finds no active job.
        """
        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
        captured = self._capture_is_running_variables(dataflow_running)
        dataflow_running.return_value = False
        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = CheckJobRunning.FinishIfRunning
        self.dataflow.multiple_jobs = True

        self.dataflow.execute(None)

        mock_callback_on_job_id.assert_called_once_with(on_new_job_id_callback=mock.ANY)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        project = dataflow_hook_mock.return_value.project_id
        assert captured['variables'] == self._expected_variables(project, JOB_NAME)
        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
        start_java_mock.assert_called_once_with(
            variables=self._expected_variables(project, job_name),
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=mock_callback_on_job_id.return_value,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=mock.ANY,
            job_name=job_name,
            location=TEST_LOCATION,
            multiple_jobs=True,
        )