예제 #1
0
    def execute(self, context):
        """Start the configured Dataflow template job via a ``DataflowHook``.

        A hook is created from this object's connection id, delegation
        account and poll interval, and ``start_template_dataflow`` is invoked
        with the job name, default options, parameters and template path.

        :param context: not read by this method.
        """
        dataflow_hook = DataflowHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            poll_sleep=self.poll_sleep,
        )

        dataflow_hook.start_template_dataflow(
            job_name=self.job_name,
            variables=self.dataflow_default_options,
            parameters=self.parameters,
            dataflow_template=self.template,
        )
class TestDataflowTemplateHook(unittest.TestCase):
    """Tests for ``DataflowHook.start_template_dataflow``."""

    def setUp(self):
        # Build the hook while CloudBaseHook.__init__ is replaced by
        # ``mock_init`` so the real base-class initialiser is not executed.
        with mock.patch(BASE_STRING.format('CloudBaseHook.__init__'),
                        new=mock_init):
            self.dataflow_hook = DataflowHook(gcp_conn_id='test')

    @mock.patch(
        DATAFLOW_STRING.format('DataflowHook._start_template_dataflow'))
    def test_start_template_dataflow(self, internal_dataflow_mock):
        """The public method delegates to ``_start_template_dataflow``.

        The delegate is expected to receive the template options layered over
        a default ``region`` of ``us-central1`` with ``project`` removed, and
        the project id as a separate trailing positional argument.
        """
        self.dataflow_hook.start_template_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_TEMPLATE,
            parameters=PARAMETERS,
            dataflow_template=TEMPLATE)

        # Expected variables: default region first, then the caller's options
        # on top, minus the project (which is passed on its own). The dict is
        # freshly built here, so no defensive deep copy is needed.
        expected_variables = {'region': 'us-central1'}
        expected_variables.update(DATAFLOW_OPTIONS_TEMPLATE)
        del expected_variables['project']

        # The expected project must come from the options actually passed in,
        # not from an unrelated fixture.
        internal_dataflow_mock.assert_called_once_with(
            mock.ANY, expected_variables, PARAMETERS,
            TEMPLATE, DATAFLOW_OPTIONS_TEMPLATE['project'])

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_template_dataflow_with_runtime_env(self, mock_conn,
                                                      mock_dataflowjob,
                                                      mock_uuid):
        """Runtime-environment options are sent as the launch body's env.

        Also checks the arguments used to build the jobs controller and that
        exactly one UUID is generated for the job name.
        """
        expected_location = 'us-central1'

        # Deep copies keep the shared module-level fixtures unmodified.
        options_with_runtime_env = copy.deepcopy(RUNTIME_ENV)
        options_with_runtime_env.update(
            copy.deepcopy(DATAFLOW_OPTIONS_TEMPLATE))

        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        method = (mock_conn.return_value.projects.return_value.locations.
                  return_value.templates.return_value.launch)
        method.return_value.execute.return_value = {'job': {'id': TEST_JOB_ID}}

        self.dataflow_hook.start_template_dataflow(
            job_name=JOB_NAME,
            variables=options_with_runtime_env,
            parameters=PARAMETERS,
            dataflow_template=TEMPLATE)

        # Single source of truth for the project id used by both assertions
        # (previously the controller check hard-coded the literal 'test').
        expected_project = options_with_runtime_env['project']

        body = {
            "jobName": mock.ANY,
            "parameters": PARAMETERS,
            "environment": RUNTIME_ENV
        }
        method.assert_called_once_with(
            projectId=expected_project,
            location=expected_location,
            gcsPath=TEMPLATE,
            body=body,
        )
        mock_dataflowjob.assert_called_once_with(
            dataflow=mock_conn.return_value,
            job_id=TEST_JOB_ID,
            location=expected_location,
            name='test-dataflow-pipeline-{}'.format(MOCK_UUID),
            num_retries=5,
            poll_sleep=10,
            project_number=expected_project)
        mock_uuid.assert_called_once_with()