示例#1
0
 def execute(self, context):
     """Launch the configured pipeline file as a Python Dataflow job.

     ``self.py_file`` is passed through
     ``GoogleCloudBucketHelper.google_cloud_to_local`` and the result is
     stored back on ``self.py_file``. The default options are then overlaid
     with ``self.options``, every option name is rewritten from
     lowerCamelCase to snake_case, and the job is submitted through
     ``DataflowHook.start_python_dataflow``.

     :param context: task execution context; not used by this method.
     """
     def _to_snake_case(name):
         # Prefix each ASCII capital with "_" and lower-case it,
         # e.g. "stagingLocation" -> "staging_location".
         return re.sub(
             r'[A-Z]', lambda match: '_' + match.group(0).lower(), name)

     local_file_helper = GoogleCloudBucketHelper(self.gcp_conn_id,
                                                 self.delegate_to)
     self.py_file = local_file_helper.google_cloud_to_local(self.py_file)

     hook = DataflowHook(gcp_conn_id=self.gcp_conn_id,
                         delegate_to=self.delegate_to,
                         poll_sleep=self.poll_sleep)

     # Operator-level options win over the defaults; work on a copy so the
     # defaults held on the operator are left untouched.
     merged_options = self.dataflow_default_options.copy()
     merged_options.update(self.options)
     snake_case_options = {
         _to_snake_case(option_name): option_value
         for option_name, option_value in merged_options.items()
     }

     hook.start_python_dataflow(
         job_name=self.job_name,
         variables=snake_case_options,
         dataflow=self.py_file,
         py_options=self.py_options,
         py_interpreter=self.py_interpreter,
         py_requirements=self.py_requirements,
         py_system_site_packages=self.py_system_site_packages,
     )
示例#2
0
    def execute(self, context):
        """Start a Dataflow job from the operator's template.

        Creates a ``DataflowHook`` from this operator's connection settings
        and calls ``start_template_dataflow`` with the job name, default
        options, parameters and template stored on the operator.

        :param context: task execution context; not used by this method.
        """
        hook = DataflowHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            poll_sleep=self.poll_sleep,
        )
        hook.start_template_dataflow(
            job_name=self.job_name,
            variables=self.dataflow_default_options,
            parameters=self.parameters,
            dataflow_template=self.template,
        )
示例#3
0
    def execute(self, context):
        """Start the Java Dataflow job unless a job with that name is running.

        The default options are shallow-copied and overlaid with
        ``self.options``. Unless ``self.check_if_running`` is
        ``CheckJobRunning.IgnoreJob``, the hook is asked whether the job is
        already running; with ``CheckJobRunning.WaitForRun`` the check is
        repeated until it reports that the job is no longer running. If the
        job is still reported as running afterwards, nothing is started.
        Otherwise ``self.jar`` is passed through
        ``GoogleCloudBucketHelper.google_cloud_to_local`` (and stored back on
        ``self.jar``) and the job is submitted via
        ``DataflowHook.start_java_dataflow``.

        :param context: task execution context; not used by this method.
        """
        hook = DataflowHook(gcp_conn_id=self.gcp_conn_id,
                            delegate_to=self.delegate_to,
                            poll_sleep=self.poll_sleep)
        dataflow_options = copy.copy(self.dataflow_default_options)
        dataflow_options.update(self.options)

        def _job_already_running():
            return hook.is_job_dataflow_running(name=self.job_name,
                                                variables=dataflow_options)

        already_running = False
        if self.check_if_running != CheckJobRunning.IgnoreJob:
            already_running = _job_already_running()
            # NOTE(review): this loop re-polls immediately; this method adds
            # no pause between checks. Confirm whether the hook throttles.
            while (already_running
                   and self.check_if_running == CheckJobRunning.WaitForRun):
                already_running = _job_already_running()

        if already_running:
            return

        jar_helper = GoogleCloudBucketHelper(self.gcp_conn_id,
                                             self.delegate_to)
        self.jar = jar_helper.google_cloud_to_local(self.jar)
        hook.start_java_dataflow(
            job_name=self.job_name,
            variables=dataflow_options,
            jar=self.jar,
            job_class=self.job_class,
            append_job_name=True,
            multiple_jobs=self.multiple_jobs,
        )
 def setUp(self):
     """Create the hook under test with ``CloudBaseHook.__init__`` patched.

     The base initializer is replaced by ``mock_init`` only while the hook
     is being constructed.
     """
     base_init_target = BASE_STRING.format('CloudBaseHook.__init__')
     with mock.patch(base_init_target, new=mock_init):
         self.dataflow_hook = DataflowHook(gcp_conn_id='test')
class TestDataflowTemplateHook(unittest.TestCase):
    """Tests for ``DataflowHook.start_template_dataflow``."""

    def setUp(self):
        """Create the hook with ``CloudBaseHook.__init__`` patched by ``mock_init``."""
        with mock.patch(BASE_STRING.format('CloudBaseHook.__init__'),
                        new=mock_init):
            self.dataflow_hook = DataflowHook(gcp_conn_id='test')

    @mock.patch(DATAFLOW_STRING.format('DataflowHook._start_template_dataflow')
                )
    def test_start_template_dataflow(self, internal_dataflow_mock):
        """Check the arguments forwarded to ``_start_template_dataflow``.

        The internal method is expected to be called once with the options
        (defaulting ``region`` to ``us-central1``) minus ``project``, the
        parameters, the template, and the project as a separate argument.
        """
        self.dataflow_hook.start_template_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_TEMPLATE,
            parameters=PARAMETERS,
            dataflow_template=TEMPLATE)
        # Build the expectation from the same fixture that was passed in:
        # start from the default region, overlay the template options, then
        # split the project out so it is compared as its own argument.
        expected_variables = {'region': 'us-central1'}
        expected_variables.update(DATAFLOW_OPTIONS_TEMPLATE)
        expected_project = expected_variables.pop('project')
        internal_dataflow_mock.assert_called_once_with(
            mock.ANY, expected_variables, PARAMETERS,
            TEMPLATE, expected_project)

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_template_dataflow_with_runtime_env(self, mock_conn,
                                                      mock_dataflowjob,
                                                      mock_uuid):
        """Check the launch request and job controller when runtime env is set.

        The variables passed in are ``RUNTIME_ENV`` overlaid with the
        template options. The test asserts the ``templates().launch`` call
        (with ``RUNTIME_ENV`` as the request ``environment``) and the
        arguments used to construct ``_DataflowJobsController``.
        """
        dataflow_options_template = copy.deepcopy(DATAFLOW_OPTIONS_TEMPLATE)
        options_with_runtime_env = copy.deepcopy(RUNTIME_ENV)
        options_with_runtime_env.update(dataflow_options_template)

        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        # Mock of conn.projects().locations().templates().launch
        method = (mock_conn.return_value.projects.return_value.locations.
                  return_value.templates.return_value.launch)

        method.return_value.execute.return_value = {'job': {'id': TEST_JOB_ID}}
        self.dataflow_hook.start_template_dataflow(
            job_name=JOB_NAME,
            variables=options_with_runtime_env,
            parameters=PARAMETERS,
            dataflow_template=TEMPLATE)
        body = {
            "jobName": mock.ANY,
            "parameters": PARAMETERS,
            "environment": RUNTIME_ENV
        }
        method.assert_called_once_with(
            projectId=options_with_runtime_env['project'],
            location='us-central1',
            gcsPath=TEMPLATE,
            body=body,
        )
        mock_dataflowjob.assert_called_once_with(
            dataflow=mock_conn.return_value,
            job_id=TEST_JOB_ID,
            location='us-central1',
            name='test-dataflow-pipeline-{}'.format(MOCK_UUID),
            num_retries=5,
            poll_sleep=10,
            project_number='test')
        mock_uuid.assert_called_once_with()
class TestDataflowHook(unittest.TestCase):
    def setUp(self):
        """Create the hook under test with ``CloudBaseHook.__init__`` patched.

        ``mock_init`` stands in for the base initializer only for the
        duration of the ``DataflowHook`` construction.
        """
        base_init_target = BASE_STRING.format('CloudBaseHook.__init__')
        with mock.patch(base_init_target, new=mock_init):
            self.dataflow_hook = DataflowHook(gcp_conn_id='test')

    @mock.patch("airflow.gcp.hooks.dataflow.DataflowHook._authorize")
    @mock.patch("airflow.gcp.hooks.dataflow.build")
    def test_dataflow_client_creation(self, mock_build, mock_authorize):
        """``get_conn`` builds the ``dataflow`` ``v1b3`` client and returns it.

        The patched ``build`` must be called exactly once with the value
        returned by ``_authorize`` as ``http`` and discovery caching off.
        """
        client = self.dataflow_hook.get_conn()

        expected_build_kwargs = {
            'http': mock_authorize.return_value,
            'cache_discovery': False,
        }
        mock_build.assert_called_once_with('dataflow', 'v1b3',
                                           **expected_build_kwargs)
        self.assertEqual(mock_build.return_value, client)

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow(self, mock_conn, mock_dataflow,
                                   mock_dataflowjob, mock_uuid):
        """Check the command line built for a Python Dataflow job.

        The first positional argument given to the patched
        ``_DataflowRunner`` is compared, order-insensitively, with the
        expected ``python3`` command.
        """
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        # Both the runner and the jobs controller report completion at once.
        for patched_class in (mock_dataflow, mock_dataflowjob):
            patched_class.return_value.wait_for_done.return_value = None

        self.dataflow_hook.start_python_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS)

        expected_cmd = [
            "python3",
            '-m',
            PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        actual_cmd = mock_dataflow.call_args[0][0]
        self.assertListEqual(sorted(actual_cmd), sorted(expected_cmd))

    @parameterized.expand([('default_to_python3', 'python3'),
                           ('major_version_2', 'python2'),
                           ('major_version_3', 'python3'),
                           ('minor_version', 'python3.6')])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_custom_interpreter(
            self, name, py_interpreter, mock_conn, mock_dataflow,
            mock_dataflowjob, mock_uuid):
        """Check that ``py_interpreter`` becomes the command's executable.

        ``name`` only labels the parameterized case. The first positional
        argument given to the patched ``_DataflowRunner`` is compared,
        order-insensitively, with the expected command.
        """
        del name  # unused variable
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        # Both the runner and the jobs controller report completion at once.
        for patched_class in (mock_dataflow, mock_dataflowjob):
            patched_class.return_value.wait_for_done.return_value = None

        self.dataflow_hook.start_python_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
            py_interpreter=py_interpreter)

        expected_cmd = [
            py_interpreter,
            '-m',
            PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        actual_cmd = mock_dataflow.call_args[0][0]
        self.assertListEqual(sorted(actual_cmd), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow(self, mock_conn, mock_dataflow,
                                 mock_dataflowjob, mock_uuid):
        """Check the ``java -jar`` command built for a Java Dataflow job.

        The first positional argument given to the patched
        ``_DataflowRunner`` is compared, order-insensitively, with the
        expected command.
        """
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        # Both the runner and the jobs controller report completion at once.
        for patched_class in (mock_dataflow, mock_dataflowjob):
            patched_class.return_value.wait_for_done.return_value = None

        self.dataflow_hook.start_java_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_JAVA,
            jar=JAR_FILE)

        expected_cmd = [
            'java',
            '-jar',
            JAR_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        actual_cmd = mock_dataflow.call_args[0][0]
        self.assertListEqual(sorted(actual_cmd), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow_with_job_class(self, mock_conn, mock_dataflow,
                                                mock_dataflowjob, mock_uuid):
        """Check the ``java -cp`` command built when a job class is given.

        The first positional argument given to the patched
        ``_DataflowRunner`` is compared, order-insensitively, with the
        expected command, which names ``JOB_CLASS`` after the jar.
        """
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        # Both the runner and the jobs controller report completion at once.
        for patched_class in (mock_dataflow, mock_dataflowjob):
            patched_class.return_value.wait_for_done.return_value = None

        self.dataflow_hook.start_java_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_JAVA,
            jar=JAR_FILE,
            job_class=JOB_CLASS)

        expected_cmd = [
            'java',
            '-cp',
            JAR_FILE,
            JOB_CLASS,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        actual_cmd = mock_dataflow.call_args[0][0]
        self.assertListEqual(sorted(actual_cmd), sorted(expected_cmd))

    @parameterized.expand([
        (JOB_NAME, JOB_NAME, False),
        ('test-example', 'test_example', False),
        ('test-dataflow-pipeline-12345678', JOB_NAME, True),
        ('test-example-12345678', 'test_example', True),
        ('df-job-1', 'df-job-1', False),
        ('df-job', 'df-job', False),
        ('dfjob', 'dfjob', False),
        ('dfjob1', 'dfjob1', False),
    ])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    def test_valid_dataflow_job_name(self, expected_result, job_name,
                                     append_job_name, mock_uuid4):
        """Check ``_build_dataflow_job_name`` against each parameterized case.

        Every case supplies the expected result, the raw job name and the
        ``append_job_name`` flag; ``uuid.uuid4`` is patched to ``MOCK_UUID``.
        """
        del mock_uuid4  # the patch only has to be active; it is not inspected
        built_name = self.dataflow_hook._build_dataflow_job_name(
            job_name=job_name,
            append_job_name=append_job_name)

        self.assertEqual(expected_result, built_name)

    @parameterized.expand([("1dfjob@", ), ("dfjob@", ), ("df^jo", )])
    def test_build_dataflow_job_name_with_invalid_value(self, job_name):
        """``_build_dataflow_job_name`` raises ``ValueError`` for each bad name."""
        with self.assertRaises(ValueError):
            self.dataflow_hook._build_dataflow_job_name(
                job_name=job_name,
                append_job_name=False)