with models.DAG(
    "example_gcp_dataflow_native_java",
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=['example'],
) as dag_native_java:
    # [START howto_operator_start_java_job_jar_on_gcs]
    start_java_job = BeamRunJavaPipelineOperator(
        task_id="start-java-job",
        jar=GCS_JAR,
        pipeline_options={
            'output': GCS_OUTPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
        dataflow_config={
            "check_if_running": CheckJobRunning.IgnoreJob,
            "location": 'europe-west3',
            "poll_sleep": 10,
        },
    )
    # [END howto_operator_start_java_job_jar_on_gcs]

    # [START howto_operator_start_java_job_local_jar]
    jar_to_local = GCSToLocalFilesystemOperator(
        task_id="jar-to-local",
        bucket=GCS_JAR_BUCKET_NAME,
        object_name=GCS_JAR_OBJECT_NAME,
        filename="/tmp/dataflow-{{ ds_nodash }}.jar",
    )
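    # The snippet above cuts off before the task that runs the downloaded jar.
    # A minimal sketch of that follow-up step, assuming the operator reuses the
    # local path written by jar_to_local; the task id and dependency wiring
    # below are illustrative, not taken from the original source.
    start_java_job_local = BeamRunJavaPipelineOperator(
        task_id="start-java-job-local",  # assumed task id
        jar="/tmp/dataflow-{{ ds_nodash }}.jar",
        pipeline_options={
            'output': GCS_OUTPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
    )
    jar_to_local >> start_java_job_local
    # [END howto_operator_start_java_job_local_jar]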
with models.DAG(
    "example_beam_native_java_direct_runner",
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=['example'],
) as dag_native_java_direct_runner:
    # [START howto_operator_start_java_direct_runner_pipeline]
    jar_to_local_direct_runner = GCSToLocalFilesystemOperator(
        task_id="jar_to_local_direct_runner",
        bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME,
        object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME,
        filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
    )

    start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator(
        task_id="start_java_pipeline_direct_runner",
        jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
        pipeline_options={
            'output': '/tmp/start_java_pipeline_direct_runner',
            'inputFile': GCS_INPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
    )

    jar_to_local_direct_runner >> start_java_pipeline_direct_runner
    # [END howto_operator_start_java_direct_runner_pipeline]

with models.DAG(
    "example_beam_native_java_dataflow_runner",
    schedule_interval=None,  # Override to match your needs
    start_date=days_ago(1),
    tags=['example'],
) as dag_native_java_dataflow_runner:
    # [START howto_operator_start_java_dataflow_runner_pipeline]
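    # The Dataflow-runner pipeline body is truncated in the original. A hedged
    # sketch of what it might look like, mirroring the direct-runner block
    # above: task ids, the reused bucket/object constants, the GCS_TMP and
    # GCS_STAGING names, and the dataflow_config values are all assumptions.
    jar_to_local_dataflow_runner = GCSToLocalFilesystemOperator(
        task_id="jar_to_local_dataflow_runner",
        bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME,  # assumed reuse of the same jar bucket
        object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME,
        filename="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
    )

    start_java_pipeline_dataflow = BeamRunJavaPipelineOperator(
        task_id="start_java_pipeline_dataflow",
        runner="DataflowRunner",
        jar="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
        pipeline_options={
            'tempLocation': GCS_TMP,          # assumed constant
            'stagingLocation': GCS_STAGING,   # assumed constant
            'output': GCS_OUTPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
        dataflow_config={"job_name": "{{task.task_id}}", "location": 'us-central1'},
    )

    jar_to_local_dataflow_runner >> start_java_pipeline_dataflow
    # [END howto_operator_start_java_dataflow_runner_pipeline]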
class TestBeamRunJavaPipelineOperator(unittest.TestCase):
    def setUp(self):
        self.operator = BeamRunJavaPipelineOperator(
            task_id=TASK_ID,
            jar=JAR_FILE,
            job_class=JOB_CLASS,
            default_pipeline_options=DEFAULT_OPTIONS_JAVA,
            pipeline_options=ADDITIONAL_OPTIONS,
        )

    def test_init(self):
        """Test BeamRunJavaPipelineOperator instance is properly initialized."""
        self.assertEqual(self.operator.task_id, TASK_ID)
        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
        self.assertEqual(self.operator.default_pipeline_options, DEFAULT_OPTIONS_JAVA)
        self.assertEqual(self.operator.job_class, JOB_CLASS)
        self.assertEqual(self.operator.jar, JAR_FILE)
        self.assertEqual(self.operator.pipeline_options, ADDITIONAL_OPTIONS)

    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
    def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
        """Test BeamHook is created and the right args are passed to start_java_pipeline."""
        start_java_hook = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.operator.execute(None)
        beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        start_java_hook.assert_called_once_with(
            variables={**DEFAULT_OPTIONS_JAVA, **ADDITIONAL_OPTIONS},
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=None,
        )

    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
        """Test DataflowHook is created and the right args are passed to start_java_pipeline."""
""" dataflow_config = DataflowConfiguration() self.operator.runner = "DataflowRunner" self.operator.dataflow_config = dataflow_config gcs_provide_file = gcs_hook.return_value.provide_file dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False self.operator.execute(None) job_name = dataflow_hook_mock.build_dataflow_job_name.return_value self.assertEqual(job_name, self.operator._dataflow_job_name) dataflow_hook_mock.assert_called_once_with( gcp_conn_id=dataflow_config.gcp_conn_id, delegate_to=dataflow_config.delegate_to, poll_sleep=dataflow_config.poll_sleep, impersonation_chain=dataflow_config.impersonation_chain, drain_pipeline=dataflow_config.drain_pipeline, cancel_timeout=dataflow_config.cancel_timeout, wait_until_finished=dataflow_config.wait_until_finished, ) gcs_provide_file.assert_called_once_with(object_url=JAR_FILE) expected_options = { 'project': dataflow_hook_mock.return_value.project_id, 'jobName': job_name, 'stagingLocation': 'gs://test/staging', 'region': 'us-central1', 'labels': { 'foo': 'bar', 'airflow-version': TEST_VERSION }, 'output': 'gs://test/output', } beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with( variables=expected_options, jar=gcs_provide_file.return_value.__enter__.return_value.name, job_class=JOB_CLASS, process_line_callback=mock.ANY, ) dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with( job_id=self.operator.dataflow_job_id, job_name=job_name, location='us-central1', multiple_jobs=dataflow_config.multiple_jobs, project_id=dataflow_hook_mock.return_value.project_id, ) @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook') @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook') @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook') def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __): self.operator.runner = "DataflowRunner" dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job self.operator.execute(None) self.operator.dataflow_job_id = JOB_ID self.operator.on_kill() dataflow_cancel_job.assert_called_once_with( job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id) @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook') @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook') @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook') def test_on_kill_direct_runner(self, _, dataflow_mock, __): dataflow_cancel_job = dataflow_mock.return_value.cancel_job self.operator.execute(None) self.operator.on_kill() dataflow_cancel_job.assert_not_called()