コード例 #1
0
class TestDataflowJavaOperator(unittest.TestCase):
    """Tests for ``DataflowCreateJavaJobOperator`` with mocked Dataflow/GCS hooks."""

    def setUp(self):
        self.dataflow = DataflowCreateJavaJobOperator(
            task_id=TASK_ID,
            jar=JAR_FILE,
            job_name=JOB_NAME,
            job_class=JOB_CLASS,
            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )

    def _assert_java_job_started(self, start_java_hook, multiple_jobs=None):
        """Assert the Java job was launched exactly once with the operator's settings.

        :param start_java_hook: the mocked ``DataflowHook.start_java_dataflow``.
        :param multiple_jobs: the ``multiple_jobs`` value the launch should receive.
        """
        start_java_hook.assert_called_once_with(
            job_name=JOB_NAME,
            variables=mock.ANY,
            jar=mock.ANY,
            job_class=JOB_CLASS,
            append_job_name=True,
            multiple_jobs=multiple_jobs,
            on_new_job_id_callback=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )

    def _assert_running_check_made(self, dataflow_running):
        """Assert ``is_job_dataflow_running`` was queried exactly once for ``JOB_NAME``.

        :param dataflow_running: the mocked ``DataflowHook.is_job_dataflow_running``.
        """
        dataflow_running.assert_called_once_with(
            name=JOB_NAME,
            variables=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )

    def test_init(self):
        """Test DataflowCreateJavaJobOperator instance is properly initialized."""
        self.assertEqual(self.dataflow.task_id, TASK_ID)
        self.assertEqual(self.dataflow.job_name, JOB_NAME)
        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
        self.assertEqual(self.dataflow.dataflow_default_options,
                         DEFAULT_OPTIONS_JAVA)
        self.assertEqual(self.dataflow.job_class, JOB_CLASS)
        self.assertEqual(self.dataflow.jar, JAR_FILE)
        # Compared against EXPECTED_ADDITIONAL_OPTIONS rather than the
        # ADDITIONAL_OPTIONS passed in: the operator is expected to transform them.
        self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS)
        self.assertEqual(self.dataflow.check_if_running,
                         CheckJobRunning.WaitForRun)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_mock):
        """Test the jar is fetched from GCS and ``start_java_dataflow`` receives
        the right args when the running-job check is ignored.
        """
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = CheckJobRunning.IgnoreJob

        self.dataflow.execute(None)

        dataflow_mock.assert_called()
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        self._assert_java_job_started(start_java_hook, multiple_jobs=None)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_running_exec(self, gcs_hook, dataflow_mock):
        """Test that nothing is fetched or launched when the job is already running."""
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = True
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        # NOTE(review): ``True`` is not a CheckJobRunning member; these tests show
        # it behaves as "check once, skip the launch if already running".
        # Consider using the matching CheckJobRunning member instead.
        self.dataflow.check_if_running = True

        self.dataflow.execute(None)

        dataflow_mock.assert_called()
        gcs_provide_file.assert_not_called()
        start_java_hook.assert_not_called()
        self._assert_running_check_made(dataflow_running)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_not_running_exec(self, gcs_hook, dataflow_mock):
        """Test the job is launched with the right args when the running-job
        check reports that no such job is running.
        """
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = False
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = True

        self.dataflow.execute(None)

        dataflow_mock.assert_called()
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        self._assert_java_job_started(start_java_hook, multiple_jobs=None)
        self._assert_running_check_made(dataflow_running)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_multiple_job_exec(self, gcs_hook, dataflow_mock):
        """Test ``multiple_jobs`` is forwarded to ``start_java_dataflow`` when the
        running-job check reports that no such job is running.
        """
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = False
        start_java_hook = dataflow_mock.return_value.start_java_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.multiple_jobs = True
        self.dataflow.check_if_running = True

        self.dataflow.execute(None)

        dataflow_mock.assert_called()
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        self._assert_java_job_started(start_java_hook, multiple_jobs=True)
        self._assert_running_check_made(dataflow_running)
コード例 #2
0
ファイル: test_dataflow.py プロジェクト: ChethanUK/airflow-1
class TestDataflowJavaOperator(unittest.TestCase):
    """Tests for ``DataflowCreateJavaJobOperator`` launching Java pipelines via ``BeamHook``."""

    def setUp(self):
        self.dataflow = DataflowCreateJavaJobOperator(
            task_id=TASK_ID,
            jar=JAR_FILE,
            job_name=JOB_NAME,
            job_class=JOB_CLASS,
            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )
        # Expected 'airflow-version' label value: 'v' followed by the Airflow
        # version with '.' and '+' replaced by '-'.
        self.expected_airflow_version = 'v' + airflow.version.version.replace(
            ".", "-").replace("+", "-")

    def _expected_variables(self, project_id, job_name):
        """Return a fresh dict of the pipeline variables the operator should build.

        :param project_id: the (mocked) ``DataflowHook.project_id``.
        :param job_name: the value expected under ``jobName``.
        """
        return {
            'project': project_id,
            'stagingLocation': 'gs://test/staging',
            'jobName': job_name,
            'region': TEST_LOCATION,
            'output': 'gs://test/output',
            'labels': {
                'foo': 'bar',
                'airflow-version': self.expected_airflow_version,
            },
        }

    @staticmethod
    def _snapshot_variables_on_call(mock_method, return_value):
        """Make ``mock_method`` record a deep copy of its ``variables`` kwarg on each call.

        A deep copy is taken because the ``jobName`` seen by the running-job
        check differs from the one later passed to ``start_java_pipeline``;
        presumably the operator mutates the same dict in between.

        A callable ``side_effect`` that does not return ``mock.DEFAULT``
        overrides ``mock_method.return_value``. The desired value is therefore
        returned explicitly from the side effect.

        :param mock_method: the mock to configure.
        :param return_value: the value each call should return.
        :return: the list that receives one snapshot per call.
        """
        snapshots = []

        def _record(*args, **kwargs):
            snapshots.append(copy.deepcopy(kwargs.get("variables")))
            return return_value

        mock_method.side_effect = _record
        return snapshots

    def _assert_java_pipeline_started(self, gcs_hook, dataflow_hook_mock,
                                      beam_hook_mock, mock_callback_on_job_id,
                                      multiple_jobs):
        """Assert the jar was fetched, the Java pipeline launched once and awaited.

        :param gcs_hook: the patched ``GCSHook`` class.
        :param dataflow_hook_mock: the patched ``DataflowHook`` class.
        :param beam_hook_mock: the patched ``BeamHook`` class.
        :param mock_callback_on_job_id: the patched
            ``process_line_and_extract_dataflow_job_id_callback``.
        :param multiple_jobs: the ``multiple_jobs`` value ``wait_for_done`` should receive.
        """
        dataflow_hook = dataflow_hook_mock.return_value
        gcs_provide_file = gcs_hook.return_value.provide_file
        # The launched pipeline uses the name built by the hook, not JOB_NAME.
        job_name = dataflow_hook.build_dataflow_job_name.return_value

        mock_callback_on_job_id.assert_called_once_with(
            on_new_job_id_callback=mock.ANY)
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
        beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
            variables=self._expected_variables(dataflow_hook.project_id, job_name),
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=mock_callback_on_job_id.return_value,
        )
        dataflow_hook.wait_for_done.assert_called_once_with(
            job_id=mock.ANY,
            job_name=job_name,
            location=TEST_LOCATION,
            multiple_jobs=multiple_jobs,
        )

    def test_init(self):
        """Test DataflowCreateJavaJobOperator instance is properly initialized."""
        self.assertEqual(self.dataflow.task_id, TASK_ID)
        self.assertEqual(self.dataflow.job_name, JOB_NAME)
        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
        self.assertEqual(self.dataflow.dataflow_default_options,
                         DEFAULT_OPTIONS_JAVA)
        self.assertEqual(self.dataflow.job_class, JOB_CLASS)
        self.assertEqual(self.dataflow.jar, JAR_FILE)
        # Compared against EXPECTED_ADDITIONAL_OPTIONS rather than the
        # ADDITIONAL_OPTIONS passed in: the operator is expected to transform them.
        self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS)
        self.assertEqual(self.dataflow.check_if_running,
                         CheckJobRunning.WaitForRun)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock,
                  mock_callback_on_job_id):
        """Test the jar is fetched, and the Java pipeline is launched via BeamHook
        and awaited, when the running-job check is ignored.
        """
        self.dataflow.check_if_running = CheckJobRunning.IgnoreJob

        self.dataflow.execute(None)

        self._assert_java_pipeline_started(
            gcs_hook, dataflow_hook_mock, beam_hook_mock,
            mock_callback_on_job_id, multiple_jobs=None)

    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_running_exec(self, gcs_hook, dataflow_mock,
                                    beam_hook_mock):
        """Test the Java pipeline is not launched when the job is already running."""
        dataflow_running = dataflow_mock.return_value.is_job_dataflow_running
        dataflow_running.return_value = True
        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.check_if_running = True

        self.dataflow.execute(None)

        dataflow_mock.assert_called()
        start_java_mock.assert_not_called()
        # The jar is still fetched from GCS even though nothing is launched.
        gcs_provide_file.assert_called_once()
        dataflow_running.assert_called_once_with(
            name=JOB_NAME,
            variables=self._expected_variables(
                dataflow_mock.return_value.project_id, JOB_NAME),
        )

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_job_not_running_exec(self, gcs_hook, dataflow_hook_mock,
                                        beam_hook_mock,
                                        mock_callback_on_job_id):
        """Test the Java pipeline is launched with the right args when the
        running-job check reports that no such job is running.
        """
        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
        checked_variables = self._snapshot_variables_on_call(
            dataflow_running, return_value=False)
        self.dataflow.check_if_running = True

        self.dataflow.execute(None)

        # The running-job check is made once, with the configured JOB_NAME.
        self.assertEqual(
            [self._expected_variables(
                dataflow_hook_mock.return_value.project_id, JOB_NAME)],
            checked_variables)
        self._assert_java_pipeline_started(
            gcs_hook, dataflow_hook_mock, beam_hook_mock,
            mock_callback_on_job_id, multiple_jobs=None)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_check_multiple_job_exec(self, gcs_hook, dataflow_hook_mock,
                                     beam_hook_mock, mock_callback_on_job_id):
        """Test ``multiple_jobs`` is forwarded to ``wait_for_done`` when the
        running-job check reports that no such job is running.
        """
        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
        checked_variables = self._snapshot_variables_on_call(
            dataflow_running, return_value=False)
        self.dataflow.check_if_running = True
        self.dataflow.multiple_jobs = True

        self.dataflow.execute(None)

        # The running-job check is made once, with the configured JOB_NAME.
        self.assertEqual(
            [self._expected_variables(
                dataflow_hook_mock.return_value.project_id, JOB_NAME)],
            checked_variables)
        self._assert_java_pipeline_started(
            gcs_hook, dataflow_hook_mock, beam_hook_mock,
            mock_callback_on_job_id, multiple_jobs=True)