Example #2
from airflow import models
from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
from airflow.providers.google.cloud.hooks.dataflow import CheckJobRunning
from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
from airflow.utils.dates import days_ago

# GCS_JAR, GCS_OUTPUT, GCS_INPUT and the bucket/object name constants are
# module-level settings (gs:// paths) defined alongside these DAGs.

with models.DAG(
        "example_gcp_dataflow_native_java",
        schedule_interval=None,  # Override to match your needs
        start_date=days_ago(1),
        tags=['example'],
) as dag_native_java:

    # [START howto_operator_start_java_job_jar_on_gcs]
    start_java_job = BeamRunJavaPipelineOperator(
        task_id="start-java-job",
        jar=GCS_JAR,
        pipeline_options={
            'output': GCS_OUTPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
        dataflow_config={
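            # CheckJobRunning controls what happens when a job with the same
            # name is already running: IgnoreJob skips the check, while
            # FinishIfRunning and WaitForRun are the alternatives.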
            "check_if_running": CheckJobRunning.IgnoreJob,
            "location": 'europe-west3',
            "poll_sleep": 10,
        },
    )
    # [END howto_operator_start_java_job_jar_on_gcs]

    # [START howto_operator_start_java_job_local_jar]
    jar_to_local = GCSToLocalFilesystemOperator(
        task_id="jar-to-local",
        bucket=GCS_JAR_BUCKET_NAME,
        object_name=GCS_JAR_OBJECT_NAME,
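        # {{ ds_nodash }} is rendered by Airflow templating at runtime
        # (e.g. 20210101), giving each run its own local jar path.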
        filename="/tmp/dataflow-{{ ds_nodash }}.jar",
    )
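
    # The task that consumes the downloaded jar is truncated in the source;
    # a minimal sketch, reusing the names defined above:
    start_java_job_local = BeamRunJavaPipelineOperator(
        task_id="start-java-job-local",
        jar="/tmp/dataflow-{{ ds_nodash }}.jar",
        pipeline_options={
            'output': GCS_OUTPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
    )

    jar_to_local >> start_java_job_local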
    # [END howto_operator_start_java_job_local_jar]

# Reconstructed opening; the original lines are truncated in the source and
# follow the pattern of the other DAG definitions here.
with models.DAG(
        "example_beam_native_java_direct_runner",
        schedule_interval=None,  # Override to match your needs
        start_date=days_ago(1),
        tags=['example'],
) as dag_native_java_direct_runner:

    # [START howto_operator_start_java_direct_runner_pipeline]
    jar_to_local_direct_runner = GCSToLocalFilesystemOperator(
        task_id="jar_to_local_direct_runner",
        bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME,
        object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME,
        filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
    )

    start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator(
        task_id="start_java_pipeline_direct_runner",
        jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
        pipeline_options={
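            # Each key/value pair is passed to the jar as a --key=value flag.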
            'output': '/tmp/start_java_pipeline_direct_runner',
            'inputFile': GCS_INPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
    )

    jar_to_local_direct_runner >> start_java_pipeline_direct_runner
    # [END howto_operator_start_java_direct_runner_pipeline]

with models.DAG(
        "example_beam_native_java_dataflow_runner",
        schedule_interval=None,  # Override to match your needs
        start_date=days_ago(1),
        tags=['example'],
) as dag_native_java_dataflow_runner:
    # [START howto_operator_start_java_dataflow_runner_pipeline]
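    # The body of this block is truncated in the source; a minimal sketch of
    # a Dataflow-runner pipeline, reusing names from the snippets above:
    start_java_pipeline_dataflow_runner = BeamRunJavaPipelineOperator(
        task_id="start_java_pipeline_dataflow_runner",
        runner="DataflowRunner",
        jar=GCS_JAR,
        pipeline_options={
            'output': GCS_OUTPUT,
        },
        job_class='org.apache.beam.examples.WordCount',
        dataflow_config={"location": 'europe-west3'},
    )
    # [END howto_operator_start_java_dataflow_runner_pipeline]
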
Example #4
import unittest
from unittest import mock

from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator
from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration

# Test constants such as TASK_ID, JAR_FILE, JOB_CLASS, DEFAULT_OPTIONS_JAVA,
# ADDITIONAL_OPTIONS, DEFAULT_RUNNER, JOB_ID and TEST_VERSION are module-level
# fixtures defined alongside this class.

class TestBeamRunJavaPipelineOperator(unittest.TestCase):
    def setUp(self):
        self.operator = BeamRunJavaPipelineOperator(
            task_id=TASK_ID,
            jar=JAR_FILE,
            job_class=JOB_CLASS,
            default_pipeline_options=DEFAULT_OPTIONS_JAVA,
            pipeline_options=ADDITIONAL_OPTIONS,
        )

    def test_init(self):
        """Test BeamRunJavaPipelineOperator instance is properly initialized."""
        self.assertEqual(self.operator.task_id, TASK_ID)
        self.assertEqual(self.operator.runner, DEFAULT_RUNNER)
        self.assertEqual(self.operator.default_pipeline_options,
                         DEFAULT_OPTIONS_JAVA)
        self.assertEqual(self.operator.job_class, JOB_CLASS)
        self.assertEqual(self.operator.jar, JAR_FILE)
        self.assertEqual(self.operator.pipeline_options, ADDITIONAL_OPTIONS)

    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
    def test_exec_direct_runner(self, gcs_hook, beam_hook_mock):
        """Test BeamHook is created and the right args are passed to
        start_java_workflow.
        """
        start_java_hook = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.operator.execute(None)

        beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
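        # provide_file is a context manager that stages the gs:// jar in a
        # local temp file; its .name is what reaches start_java_pipeline.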
        start_java_hook.assert_called_once_with(
            variables={
                **DEFAULT_OPTIONS_JAVA,
                **ADDITIONAL_OPTIONS
            },
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=None,
        )

    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock,
                                  beam_hook_mock):
        """Test DataflowHook is created and the right args are passed to
        start_java_dataflow.
        """
        dataflow_config = DataflowConfiguration()
        self.operator.runner = "DataflowRunner"
        self.operator.dataflow_config = dataflow_config
        gcs_provide_file = gcs_hook.return_value.provide_file
        dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
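        # Report no already-running job so execute() starts a new pipeline.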
        self.operator.execute(None)
        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
        self.assertEqual(job_name, self.operator._dataflow_job_name)
        dataflow_hook_mock.assert_called_once_with(
            gcp_conn_id=dataflow_config.gcp_conn_id,
            delegate_to=dataflow_config.delegate_to,
            poll_sleep=dataflow_config.poll_sleep,
            impersonation_chain=dataflow_config.impersonation_chain,
            drain_pipeline=dataflow_config.drain_pipeline,
            cancel_timeout=dataflow_config.cancel_timeout,
            wait_until_finished=dataflow_config.wait_until_finished,
        )
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)

        expected_options = {
            'project': dataflow_hook_mock.return_value.project_id,
            'jobName': job_name,
            'stagingLocation': 'gs://test/staging',
            'region': 'us-central1',
            'labels': {
                'foo': 'bar',
                'airflow-version': TEST_VERSION
            },
            'output': 'gs://test/output',
        }

        beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
            variables=expected_options,
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=mock.ANY,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=self.operator.dataflow_job_id,
            job_name=job_name,
            location='us-central1',
            multiple_jobs=dataflow_config.multiple_jobs,
            project_id=dataflow_hook_mock.return_value.project_id,
        )

    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
    def test_on_kill_dataflow_runner(self, dataflow_hook_mock, _, __):
        """Test that on_kill cancels the Dataflow job when using DataflowRunner."""
        self.operator.runner = "DataflowRunner"
        dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
        dataflow_cancel_job = dataflow_hook_mock.return_value.cancel_job
        self.operator.execute(None)
        self.operator.dataflow_job_id = JOB_ID
        self.operator.on_kill()
        dataflow_cancel_job.assert_called_once_with(
            job_id=JOB_ID, project_id=self.operator.dataflow_config.project_id)

    @mock.patch('airflow.providers.apache.beam.operators.beam.BeamHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.DataflowHook')
    @mock.patch('airflow.providers.apache.beam.operators.beam.GCSHook')
    def test_on_kill_direct_runner(self, _, dataflow_mock, __):
        """Test that on_kill cancels nothing when using the direct runner."""
        dataflow_cancel_job = dataflow_mock.return_value.cancel_job
        self.operator.execute(None)
        self.operator.on_kill()
        dataflow_cancel_job.assert_not_called()
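
To run the suite standalone, a standard unittest entry point is enough:

if __name__ == "__main__":
    unittest.main()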