def execute(self, context):
    """Build a ``DataflowHook`` and call ``start_template_dataflow`` on it.

    The hook is configured from this object's ``gcp_conn_id``,
    ``delegate_to`` and ``poll_sleep`` attributes; the job name, default
    options, parameters and template are forwarded to the hook call.

    :param context: accepted for the caller's benefit; not read here.
    :return: None
    """
    dataflow_hook = DataflowHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        poll_sleep=self.poll_sleep,
    )
    # Collect the launch arguments first so the call itself stays short.
    launch_kwargs = {
        'job_name': self.job_name,
        'variables': self.dataflow_default_options,
        'parameters': self.parameters,
        'dataflow_template': self.template,
    }
    dataflow_hook.start_template_dataflow(**launch_kwargs)
class TestDataflowTemplateHook(unittest.TestCase):
    """Tests for ``DataflowHook.start_template_dataflow``."""

    def setUp(self):
        # Construct the hook under test while the base hook's __init__ is
        # swapped for ``mock_init``.
        base_init_target = BASE_STRING.format('CloudBaseHook.__init__')
        with mock.patch(base_init_target, new=mock_init):
            self.dataflow_hook = DataflowHook(gcp_conn_id='test')

    @mock.patch(DATAFLOW_STRING.format('DataflowHook._start_template_dataflow'))
    def test_start_template_dataflow(self, internal_dataflow_mock):
        # The public entry point should delegate to the private launcher with
        # 'region' set, 'project' removed from the variables, and the project
        # passed as the last positional argument.
        self.dataflow_hook.start_template_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_TEMPLATE,
            parameters=PARAMETERS,
            dataflow_template=TEMPLATE,
        )

        # Expectations are deliberately built after the call, from the
        # module-level options as they are at that point.
        region_defaults = {'region': 'us-central1'}
        region_defaults.update(DATAFLOW_OPTIONS_TEMPLATE)
        expected_variables = copy.deepcopy(region_defaults)
        expected_variables.pop('project')

        # NOTE(review): the expected project is read from DATAFLOW_OPTIONS_JAVA
        # although the call above used DATAFLOW_OPTIONS_TEMPLATE -- presumably
        # both constants share the same 'project' value; confirm.
        internal_dataflow_mock.assert_called_once_with(
            mock.ANY,
            expected_variables,
            PARAMETERS,
            TEMPLATE,
            DATAFLOW_OPTIONS_JAVA['project'],
        )

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflowjob, mock_uuid):
        # Variables passed to the hook: RUNTIME_ENV entries overlaid with the
        # template options.
        template_options = copy.deepcopy(DATAFLOW_OPTIONS_TEMPLATE)
        variables_with_env = copy.deepcopy(RUNTIME_ENV)
        variables_with_env.update(template_options)

        # The mocked jobs controller instance reports completion immediately.
        mock_dataflowjob.return_value.wait_for_done.return_value = None

        # Walk the mocked client down to projects().locations().templates().launch
        # through return_value attributes, so no calls are recorded on the way.
        projects_resource = mock_conn.return_value.projects.return_value
        templates_resource = projects_resource.locations.return_value.templates.return_value
        launch_mock = templates_resource.launch
        launch_mock.return_value.execute.return_value = {'job': {'id': TEST_JOB_ID}}

        self.dataflow_hook.start_template_dataflow(
            job_name=JOB_NAME,
            variables=variables_with_env,
            parameters=PARAMETERS,
            dataflow_template=TEMPLATE,
        )

        expected_body = {
            "jobName": mock.ANY,
            "parameters": PARAMETERS,
            "environment": RUNTIME_ENV,
        }
        launch_mock.assert_called_once_with(
            projectId=variables_with_env['project'],
            location='us-central1',
            gcsPath=TEMPLATE,
            body=expected_body,
        )

        expected_job_name = 'test-dataflow-pipeline-{}'.format(MOCK_UUID)
        mock_dataflowjob.assert_called_once_with(
            dataflow=mock_conn.return_value,
            job_id=TEST_JOB_ID,
            location='us-central1',
            name=expected_job_name,
            num_retries=5,
            poll_sleep=10,
            project_number='test',
        )
        mock_uuid.assert_called_once_with()