Example No. 1
class TestDataflowPythonOperator(unittest.TestCase):
    def setUp(self):
        self.dataflow = DataflowCreatePythonJobOperator(
            task_id=TASK_ID,
            py_file=PY_FILE,
            job_name=JOB_NAME,
            py_options=PY_OPTIONS,
            dataflow_default_options=DEFAULT_OPTIONS_PYTHON,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )

    def test_init(self):
        """Test DataFlowPythonOperator instance is properly initialized."""
        self.assertEqual(self.dataflow.task_id, TASK_ID)
        self.assertEqual(self.dataflow.job_name, JOB_NAME)
        self.assertEqual(self.dataflow.py_file, PY_FILE)
        self.assertEqual(self.dataflow.py_options, PY_OPTIONS)
        self.assertEqual(self.dataflow.py_interpreter, PY_INTERPRETER)
        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
        self.assertEqual(self.dataflow.dataflow_default_options,
                         DEFAULT_OPTIONS_PYTHON)
        self.assertEqual(self.dataflow.options, EXPECTED_ADDITIONAL_OPTIONS)

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_mock):
        """Test DataflowHook is created and the right args are passed to
        start_python_workflow.

        """
        start_python_hook = dataflow_mock.return_value.start_python_dataflow
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.dataflow.execute(None)
        self.assertTrue(dataflow_mock.called)
        expected_options = {
            'project': 'test',
            'staging_location': 'gs://test/staging',
            'output': 'gs://test/output',
            'labels': {
                'foo': 'bar',
                'airflow-version': TEST_VERSION
            },
        }
        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
        start_python_hook.assert_called_once_with(
            job_name=JOB_NAME,
            variables=expected_options,
            dataflow=mock.ANY,
            py_options=PY_OPTIONS,
            py_interpreter=PY_INTERPRETER,
            py_requirements=None,
            py_system_site_packages=False,
            on_new_job_id_callback=mock.ANY,
            project_id=None,
            location=TEST_LOCATION,
        )
        self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
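Both examples rely on module-level fixtures that the snippets do not show. The sketch below reconstructs the imports and constants the tests assume: the two option dicts are implied by the expected_options asserted in test_exec, while the remaining values (TASK_ID, PY_FILE, JOB_NAME, and so on) are illustrative placeholders rather than the original test module's values.

# Hedged fixture sketch (assumed values; only the option dicts are implied by
# the expected_options dictionary asserted in test_exec above).
import unittest
from unittest import mock

import airflow
from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator

TASK_ID = 'test-dataflow-operator'            # hypothetical task id
PY_FILE = 'gs://test/dataflow/pipeline.py'    # hypothetical GCS path to the pipeline file
JOB_NAME = 'test-dataflow-pipeline'           # hypothetical job name
PY_OPTIONS = ['-m']                           # hypothetical interpreter options
PY_INTERPRETER = 'python3'                    # hypothetical interpreter
POLL_SLEEP = 30                               # hypothetical poll interval in seconds
TEST_LOCATION = 'custom-location'             # hypothetical regional endpoint
TEST_VERSION = 'v2-1-0'                       # hypothetical sanitized Airflow version label

# Implied by expected_options: 'project' and 'staging_location' come from the
# default options, while 'output' and the user label come from the per-task options.
DEFAULT_OPTIONS_PYTHON = {
    'project': 'test',
    'staging_location': 'gs://test/staging',
}
ADDITIONAL_OPTIONS = {
    'output': 'gs://test/output',
    'labels': {'foo': 'bar'},
}
EXPECTED_ADDITIONAL_OPTIONS = ADDITIONAL_OPTIONS  # assumed identical here for simplicity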
Example No. 2
class TestDataflowPythonOperator(unittest.TestCase):
    def setUp(self):
        self.dataflow = DataflowCreatePythonJobOperator(
            task_id=TASK_ID,
            py_file=PY_FILE,
            job_name=JOB_NAME,
            py_options=PY_OPTIONS,
            dataflow_default_options=DEFAULT_OPTIONS_PYTHON,
            options=ADDITIONAL_OPTIONS,
            poll_sleep=POLL_SLEEP,
            location=TEST_LOCATION,
        )
        self.expected_airflow_version = 'v' + airflow.version.version.replace(
            ".", "-").replace("+", "-")

    def test_init(self):
        """Test DataFlowPythonOperator instance is properly initialized."""
        assert self.dataflow.task_id == TASK_ID
        assert self.dataflow.job_name == JOB_NAME
        assert self.dataflow.py_file == PY_FILE
        assert self.dataflow.py_options == PY_OPTIONS
        assert self.dataflow.py_interpreter == PY_INTERPRETER
        assert self.dataflow.poll_sleep == POLL_SLEEP
        assert self.dataflow.dataflow_default_options == DEFAULT_OPTIONS_PYTHON
        assert self.dataflow.options == EXPECTED_ADDITIONAL_OPTIONS

    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.process_line_and_extract_dataflow_job_id_callback'
    )
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
    @mock.patch(
        'airflow.providers.google.cloud.operators.dataflow.DataflowHook')
    @mock.patch('airflow.providers.google.cloud.operators.dataflow.GCSHook')
    def test_exec(self, gcs_hook, dataflow_hook_mock, beam_hook_mock,
                  mock_callback_on_job_id):
        """Test DataflowHook is created and the right args are passed to
        start_python_workflow.

        """
        start_python_mock = beam_hook_mock.return_value.start_python_pipeline
        gcs_provide_file = gcs_hook.return_value.provide_file
        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
        self.dataflow.execute(None)
        beam_hook_mock.assert_called_once_with(runner="DataflowRunner")
        assert self.dataflow.py_file.startswith('/tmp/dataflow')
        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
        mock_callback_on_job_id.assert_called_once_with(
            on_new_job_id_callback=mock.ANY)
        dataflow_hook_mock.assert_called_once_with(
            gcp_conn_id="google_cloud_default",
            delegate_to=mock.ANY,
            poll_sleep=POLL_SLEEP,
            impersonation_chain=None,
            drain_pipeline=False,
            cancel_timeout=mock.ANY,
            wait_until_finished=None,
        )
        expected_options = {
            "project": dataflow_hook_mock.return_value.project_id,
            "staging_location": 'gs://test/staging',
            "job_name": job_name,
            "region": TEST_LOCATION,
            'output': 'gs://test/output',
            'labels': {
                'foo': 'bar',
                'airflow-version': self.expected_airflow_version
            },
        }
        start_python_mock.assert_called_once_with(
            variables=expected_options,
            py_file=gcs_provide_file.return_value.__enter__.return_value.name,
            py_options=PY_OPTIONS,
            py_interpreter=PY_INTERPRETER,
            py_requirements=None,
            py_system_site_packages=False,
            process_line_callback=mock_callback_on_job_id.return_value,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=mock.ANY,
            job_name=job_name,
            location=TEST_LOCATION,
            multiple_jobs=False,
        )
        assert self.dataflow.py_file.startswith('/tmp/dataflow')
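The main difference between the two fixtures is how the 'airflow-version' label is obtained: Example No. 1 reads a precomputed TEST_VERSION constant, while Example No. 2 derives the value in setUp. A minimal sketch of the shared derivation, mirroring the replace chain used in setUp above:

# Sketch: the 'airflow-version' label expected by both tests is the installed
# Airflow version prefixed with 'v', with '.' and '+' replaced by '-' so the
# value is a valid Dataflow label.
import airflow.version

TEST_VERSION = 'v' + airflow.version.version.replace('.', '-').replace('+', '-')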