def execute(self, context):
    hook = DataflowHook(
        gcp_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to,
        poll_sleep=self.poll_sleep,
    )
    dataflow_options = copy.copy(self.dataflow_default_options)
    dataflow_options.update(self.options)
    is_running = False
    if self.check_if_running != CheckJobRunning.IgnoreJob:
        is_running = hook.is_job_dataflow_running(
            name=self.job_name, variables=dataflow_options
        )
        while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
            is_running = hook.is_job_dataflow_running(
                name=self.job_name, variables=dataflow_options
            )
    if not is_running:
        bucket_helper = GoogleCloudBucketHelper(self.gcp_conn_id, self.delegate_to)
        self.jar = bucket_helper.google_cloud_to_local(self.jar)
        hook.start_java_dataflow(
            job_name=self.job_name,
            variables=dataflow_options,
            jar=self.jar,
            job_class=self.job_class,
            append_job_name=True,
            multiple_jobs=self.multiple_jobs,
        )
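
# The execute flow above branches on CheckJobRunning, which is not part of this
# excerpt. A minimal sketch of the enum it assumes (matching the three modes
# described in the operator docstring below) could look like this:
from enum import Enum


class CheckJobRunning(Enum):
    """Sketch of the helper enum for the check_if_running parameter."""

    IgnoreJob = 1        # do not check whether a job with the same name is running
    FinishIfRunning = 2  # if a job is already running, finish without starting a new one
    WaitForRun = 3       # poll until the running job finishes, then start the new run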
class DataflowCreateJavaJobOperator(BaseOperator):
    """
    Start a Java Cloud DataFlow batch job. The parameters of the operation
    will be passed to the job.

    **Example**: ::

        default_args = {
            'owner': 'airflow',
            'depends_on_past': False,
            'start_date': datetime(2016, 8, 1),
            'email': ['*****@*****.**'],
            'email_on_failure': False,
            'email_on_retry': False,
            'retries': 1,
            'retry_delay': timedelta(minutes=30),
            'dataflow_default_options': {
                'project': 'my-gcp-project',
                'zone': 'us-central1-f',
                'stagingLocation': 'gs://bucket/tmp/dataflow/staging/',
            }
        }

        dag = DAG('test-dag', default_args=default_args)

        task = DataflowCreateJavaJobOperator(
            gcp_conn_id='gcp_default',
            task_id='normalize-cal',
            jar='{{var.value.gcp_dataflow_base}}pipeline-ingress-cal-normalize-1.0.jar',
            options={
                'autoscalingAlgorithm': 'BASIC',
                'maxNumWorkers': '50',
                'start': '{{ds}}',
                'partitionType': 'DAY'
            },
            dag=dag)

    .. seealso::
        For more detail on job submission have a look at the reference:
        https://cloud.google.com/dataflow/pipelines/specifying-exec-params

    :param jar: The reference to a self executing DataFlow jar (templated).
    :type jar: str
    :param job_name: The 'jobName' to use when executing the DataFlow job
        (templated). This ends up being set in the pipeline options, so any entry
        with key ``'jobName'`` in ``options`` will be overwritten.
    :type job_name: str
    :param dataflow_default_options: Map of default job options.
    :type dataflow_default_options: dict
    :param options: Map of job specific options. The value can contain different
        types:

        * If the value is None, the single option - ``--key`` (without value) will
          be added.
        * If the value is False, this option will be skipped.
        * If the value is True, the single option - ``--key`` (without value) will
          be added.
        * If the value is a list, the option will be repeated for each element.
          If the value is ``['A', 'B']`` and the key is ``key`` then the
          ``--key=A --key=B`` options will be passed.
        * Other value types will be replaced with the Python textual representation.

        When defining labels (``labels`` option), you can also provide a dictionary.
    :type options: dict
    :param project_id: Optional, the Google Cloud project ID in which to start a job.
        If set to None or missing, the default project_id from the Google Cloud
        connection is used.
    :type project_id: str
    :param location: Job location.
    :type location: str
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :type gcp_conn_id: str
    :param delegate_to: The account to impersonate using domain-wide delegation
        of authority, if any. For this to work, the service account making the
        request must have domain-wide delegation enabled.
    :type delegate_to: str
    :param poll_sleep: The time in seconds to sleep between polling Google Cloud
        Platform for the dataflow job status while the job is in the
        JOB_STATE_RUNNING state.
    :type poll_sleep: int
    :param job_class: The name of the dataflow job class to be executed, which is
        often not the main class configured in the dataflow jar file.
    :type job_class: str
    :param multiple_jobs: If the pipeline creates multiple jobs, monitor all of them.
    :type multiple_jobs: boolean
    :param check_if_running: Before running the job, validate that a previous run
        is not in progress.
    :type check_if_running: CheckJobRunning (IgnoreJob = do not check if running;
        FinishIfRunning = if the job is running, finish without doing anything;
        WaitForRun = wait until the job finishes and then run the new job)

    ``jar``, ``options``, and ``job_name`` are templated, so you can use variables
    in them.

    Note that both ``dataflow_default_options`` and ``options`` will be merged to
    specify the pipeline execution parameters, and ``dataflow_default_options`` is
    expected to hold high-level options, for instance, project and zone information,
    which apply to all dataflow operators in the DAG.

    It's a good practice to define dataflow_* parameters in the default_args of the
    DAG, like the project, zone and staging location.

    .. code-block:: python

        default_args = {
            'dataflow_default_options': {
                'zone': 'europe-west1-d',
                'stagingLocation': 'gs://my-staging-bucket/staging/'
            }
        }

    You need to pass the path to your dataflow as a file reference with the ``jar``
    parameter; the jar needs to be a self executing jar (see documentation here:
    https://beam.apache.org/documentation/runners/dataflow/#self-executing-jar).
    Use ``options`` to pass on options to your job.

    .. code-block:: python

        t1 = DataflowCreateJavaJobOperator(
            task_id='dataflow_example',
            jar='{{var.value.gcp_dataflow_base}}pipeline/build/libs/pipeline-example-1.0.jar',
            options={
                'autoscalingAlgorithm': 'BASIC',
                'maxNumWorkers': '50',
                'start': '{{ds}}',
                'partitionType': 'DAY',
                'labels': {'foo': 'bar'}
            },
            gcp_conn_id='airflow-conn-id',
            dag=my_dag)
    """

    template_fields = ['options', 'jar', 'job_name']
    ui_color = '#0273d4'

    # pylint: disable=too-many-arguments
    @apply_defaults
    def __init__(
        self,
        *,
        jar: str,
        job_name: str = '{{task.task_id}}',
        dataflow_default_options: Optional[dict] = None,
        options: Optional[dict] = None,
        project_id: Optional[str] = None,
        location: str = DEFAULT_DATAFLOW_LOCATION,
        gcp_conn_id: str = 'google_cloud_default',
        delegate_to: Optional[str] = None,
        poll_sleep: int = 10,
        job_class: Optional[str] = None,
        check_if_running: CheckJobRunning = CheckJobRunning.WaitForRun,
        multiple_jobs: Optional[bool] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        dataflow_default_options = dataflow_default_options or {}
        options = options or {}
        options.setdefault('labels', {}).update(
            {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')}
        )
        self.project_id = project_id
        self.location = location
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.jar = jar
        self.multiple_jobs = multiple_jobs
        self.job_name = job_name
        self.dataflow_default_options = dataflow_default_options
        self.options = options
        self.poll_sleep = poll_sleep
        self.job_class = job_class
        self.check_if_running = check_if_running
        self.job_id = None
        self.hook = None

    def execute(self, context):
        self.hook = DataflowHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            poll_sleep=self.poll_sleep,
        )
        dataflow_options = copy.copy(self.dataflow_default_options)
        dataflow_options.update(self.options)
        is_running = False
        if self.check_if_running != CheckJobRunning.IgnoreJob:
            is_running = self.hook.is_job_dataflow_running(
                name=self.job_name,
                variables=dataflow_options,
                project_id=self.project_id,
                location=self.location,
            )
            while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
                is_running = self.hook.is_job_dataflow_running(
                    name=self.job_name,
                    variables=dataflow_options,
                    project_id=self.project_id,
                    location=self.location,
                )
        if not is_running:
            with ExitStack() as exit_stack:
                if self.jar.lower().startswith('gs://'):
                    gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
                    tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
                        gcs_hook.provide_file(object_url=self.jar)
                    )
                    self.jar = tmp_gcs_file.name

                def set_current_job_id(job_id):
                    self.job_id = job_id

                self.hook.start_java_dataflow(
                    job_name=self.job_name,
                    variables=dataflow_options,
                    jar=self.jar,
                    job_class=self.job_class,
                    append_job_name=True,
                    multiple_jobs=self.multiple_jobs,
                    on_new_job_id_callback=set_current_job_id,
                    project_id=self.project_id,
                    location=self.location,
                )

    def on_kill(self) -> None:
        self.log.info("On kill.")
        if self.job_id:
            self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
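
# A minimal usage sketch of the operator defined above, assuming hypothetical
# project/bucket/jar names and that CheckJobRunning is importable from the same
# operators module; it is not taken from the original source.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import (
    CheckJobRunning,
    DataflowCreateJavaJobOperator,
)

with DAG(
    dag_id="example_dataflow_java_job",  # hypothetical DAG id
    start_date=datetime(2021, 1, 1),
    schedule_interval=None,
) as dag:
    start_java_job = DataflowCreateJavaJobOperator(
        task_id="start-java-job",
        jar="gs://my-bucket/jars/pipeline-example-1.0.jar",  # hypothetical GCS path
        job_name="example-java-job",
        options={"stagingLocation": "gs://my-bucket/staging/"},
        location="europe-west3",
        # Wait for a previous run with the same job_name to finish before starting a new one.
        check_if_running=CheckJobRunning.WaitForRun,
    )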
class TestDataflowHook(unittest.TestCase):

    def setUp(self):
        with mock.patch(BASE_STRING.format('CloudBaseHook.__init__'), new=mock_init):
            self.dataflow_hook = DataflowHook(gcp_conn_id='test')

    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
    def test_dataflow_client_creation(self, mock_build, mock_authorize):
        result = self.dataflow_hook.get_conn()
        mock_build.assert_called_once_with(
            'dataflow', 'v1b3', http=mock_authorize.return_value, cache_discovery=False
        )
        self.assertEqual(mock_build.return_value, result)

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_python_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
        )
        expected_cmd = [
            "python3", '-m', PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @parameterized.expand([
        ('default_to_python3', 'python3'),
        ('major_version_2', 'python2'),
        ('major_version_3', 'python3'),
        ('minor_version', 'python3.6'),
    ])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_custom_interpreter(
        self, name, py_interpreter, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        del name  # unused variable
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_python_dataflow(
            job_name=JOB_NAME,
            variables=DATAFLOW_OPTIONS_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
            py_interpreter=py_interpreter,
        )
        expected_cmd = [
            py_interpreter, '-m', PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_java_dataflow(
            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_JAVA, jar=JAR_FILE
        )
        expected_cmd = [
            'java', '-jar', JAR_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow_with_job_class(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_java_dataflow(
            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_JAVA, jar=JAR_FILE, job_class=JOB_CLASS
        )
        expected_cmd = [
            'java', '-cp', JAR_FILE, JOB_CLASS,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @parameterized.expand([
        (JOB_NAME, JOB_NAME, False),
        ('test-example', 'test_example', False),
        ('test-dataflow-pipeline-12345678', JOB_NAME, True),
        ('test-example-12345678', 'test_example', True),
        ('df-job-1', 'df-job-1', False),
        ('df-job', 'df-job', False),
        ('dfjob', 'dfjob', False),
        ('dfjob1', 'dfjob1', False),
    ])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_name, mock_uuid4):
        job_name = self.dataflow_hook._build_dataflow_job_name(
            job_name=job_name, append_job_name=append_job_name
        )
        self.assertEqual(expected_result, job_name)

    @parameterized.expand([("1dfjob@",), ("dfjob@",), ("df^jo",)])
    def test_build_dataflow_job_name_with_invalid_value(self, job_name):
        self.assertRaises(
            ValueError,
            self.dataflow_hook._build_dataflow_job_name,
            job_name=job_name,
            append_job_name=False,
        )
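
# The expected_cmd lists above show how the hook's `variables` dict is rendered
# into command-line flags. The helper below is only an illustrative sketch of the
# mapping described in the operator docstring (None/True -> bare flag, False
# skipped, lists repeated, dict labels JSON-encoded as in the Java tests); it is
# not the hook's actual implementation, and the function name is hypothetical.
import json
from typing import Any, Dict, List


def options_to_args(options: Dict[str, Any]) -> List[str]:
    """Render a pipeline options dict into ``--key=value`` style arguments."""
    args: List[str] = []
    for key, value in options.items():
        if value is None or value is True:
            args.append(f"--{key}")
        elif value is False:
            continue
        elif isinstance(value, list):
            args.extend(f"--{key}={item}" for item in value)
        elif isinstance(value, dict):
            # e.g. labels for a Java pipeline end up JSON-encoded, as in the
            # expected_cmd above: --labels={"foo":"bar"}
            args.append(f"--{key}={json.dumps(value, separators=(',', ':'))}")
        else:
            args.append(f"--{key}={value}")
    return args


print(options_to_args({"labels": {"foo": "bar"}, "extra-package": ["a.whl", "b.whl"]}))
# ['--labels={"foo":"bar"}', '--extra-package=a.whl', '--extra-package=b.whl']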
class TestDataflowHook(unittest.TestCase):

    def setUp(self):
        with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
            self.dataflow_hook = DataflowHook(gcp_conn_id='test')

    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
    @mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
    def test_dataflow_client_creation(self, mock_build, mock_authorize):
        result = self.dataflow_hook.get_conn()
        mock_build.assert_called_once_with(
            'dataflow', 'v1b3', http=mock_authorize.return_value, cache_discovery=False
        )
        self.assertEqual(mock_build.return_value, result)

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME,
            variables=DATAFLOW_VARIABLES_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
        )
        expected_cmd = [
            "python3", '-m', PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_custom_region_as_variable(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        variables = copy.deepcopy(DATAFLOW_VARIABLES_PY)
        variables['region'] = TEST_LOCATION
        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME,
            variables=variables,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
        )
        expected_cmd = [
            "python3", '-m', PY_FILE,
            f'--region={TEST_LOCATION}',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_custom_region_as_parameter(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME,
            variables=DATAFLOW_VARIABLES_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
            location=TEST_LOCATION,
        )
        expected_cmd = [
            "python3", '-m', PY_FILE,
            f'--region={TEST_LOCATION}',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_multiple_extra_packages(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_PY)
        variables['extra-package'] = ['a.whl', 'b.whl']
        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME,
            variables=variables,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
        )
        expected_cmd = [
            "python3", '-m', PY_FILE,
            '--extra-package=a.whl',
            '--extra-package=b.whl',
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @parameterized.expand([
        ('default_to_python3', 'python3'),
        ('major_version_2', 'python2'),
        ('major_version_3', 'python3'),
        ('minor_version', 'python3.6'),
    ])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_custom_interpreter(
        self,
        name,
        py_interpreter,
        mock_conn,
        mock_dataflow,
        mock_dataflowjob,
        mock_uuid,
    ):
        del name  # unused variable
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME,
            variables=DATAFLOW_VARIABLES_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
            py_interpreter=py_interpreter,
        )
        expected_cmd = [
            py_interpreter, '-m', PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @parameterized.expand([
        (['foo-bar'], False),
        (['foo-bar'], True),
        ([], True),
    ])
    @mock.patch(DATAFLOW_STRING.format('prepare_virtualenv'))
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_non_empty_py_requirements_and_without_system_packages(
        self,
        current_py_requirements,
        current_py_system_site_packages,
        mock_conn,
        mock_dataflow,
        mock_dataflowjob,
        mock_uuid,
        mock_virtualenv,
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        mock_virtualenv.return_value = '/dummy_dir/bin/python'
        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME,
            variables=DATAFLOW_VARIABLES_PY,
            dataflow=PY_FILE,
            py_options=PY_OPTIONS,
            py_requirements=current_py_requirements,
            py_system_site_packages=current_py_system_site_packages,
        )
        expected_cmd = [
            '/dummy_dir/bin/python', '-m', PY_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--labels=foo=bar',
            '--staging_location=gs://test/staging',
            '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_python_dataflow_with_empty_py_requirements_and_without_system_packages(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        with self.assertRaisesRegex(AirflowException, "Invalid method invocation."):
            self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
                job_name=JOB_NAME,
                variables=DATAFLOW_VARIABLES_PY,
                dataflow=PY_FILE,
                py_options=PY_OPTIONS,
                py_requirements=[],
            )

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow(self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE
        )
        expected_cmd = [
            'java', '-jar', JAR_FILE,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(
            sorted(expected_cmd),
            sorted(mock_dataflow.call_args[1]["cmd"]),
        )

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow_with_multiple_values_in_variables(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
        variables['mock-option'] = ['a.whl', 'b.whl']
        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME, variables=variables, jar=JAR_FILE
        )
        expected_cmd = [
            'java', '-jar', JAR_FILE,
            '--mock-option=a.whl',
            '--mock-option=b.whl',
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow_with_custom_region_as_variable(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        variables = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
        variables['region'] = TEST_LOCATION
        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME, variables=variables, jar=JAR_FILE
        )
        expected_cmd = [
            'java', '-jar', JAR_FILE,
            f'--region={TEST_LOCATION}',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(
            sorted(expected_cmd),
            sorted(mock_dataflow.call_args[1]["cmd"]),
        )

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow_with_custom_region_as_parameter(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE, location=TEST_LOCATION
        )
        expected_cmd = [
            'java', '-jar', JAR_FILE,
            f'--region={TEST_LOCATION}',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(
            sorted(expected_cmd),
            sorted(mock_dataflow.call_args[1]["cmd"]),
        )

    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_start_java_dataflow_with_job_class(
        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
    ):
        mock_uuid.return_value = MOCK_UUID
        mock_conn.return_value = None
        dataflow_instance = mock_dataflow.return_value
        dataflow_instance.wait_for_done.return_value = None
        dataflowjob_instance = mock_dataflowjob.return_value
        dataflowjob_instance.wait_for_done.return_value = None
        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA, jar=JAR_FILE, job_class=JOB_CLASS
        )
        expected_cmd = [
            'java', '-cp', JAR_FILE, JOB_CLASS,
            '--region=us-central1',
            '--runner=DataflowRunner',
            '--project=test',
            '--stagingLocation=gs://test/staging',
            '--labels={"foo":"bar"}',
            '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID),
        ]
        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))

    @parameterized.expand([
        (JOB_NAME, JOB_NAME, False),
        ('test-example', 'test_example', False),
        ('test-dataflow-pipeline-12345678', JOB_NAME, True),
        ('test-example-12345678', 'test_example', True),
        ('df-job-1', 'df-job-1', False),
        ('df-job', 'df-job', False),
        ('dfjob', 'dfjob', False),
        ('dfjob1', 'dfjob1', False),
    ])
    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
    def test_valid_dataflow_job_name(self, expected_result, job_name, append_job_name, mock_uuid4):
        job_name = self.dataflow_hook._build_dataflow_job_name(
            job_name=job_name, append_job_name=append_job_name
        )
        self.assertEqual(expected_result, job_name)

    @parameterized.expand([("1dfjob@",), ("dfjob@",), ("df^jo",)])
    def test_build_dataflow_job_name_with_invalid_value(self, job_name):
        self.assertRaises(
            ValueError,
            self.dataflow_hook._build_dataflow_job_name,
            job_name=job_name,
            append_job_name=False,
        )
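
# The name tests above exercise the hook's private _build_dataflow_job_name,
# whose implementation is outside this excerpt. The function below is only a
# sketch of the behaviour those test cases imply (underscores normalised to
# hyphens, lowercase letter first, then lowercase letters/digits/hyphens, and an
# 8-character suffix appended when append_job_name is set); it is not the hook's
# actual code, and the helper name is hypothetical.
import re
import uuid


def build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
    """Validate and optionally suffix a Dataflow job name (illustrative sketch)."""
    base_job_name = str(job_name).replace('_', '-')
    if not re.fullmatch(r"[a-z]([-a-z0-9]*[a-z0-9])?", base_job_name):
        raise ValueError(f"Invalid job_name ({base_job_name})")
    if append_job_name:
        return f"{base_job_name}-{uuid.uuid4().hex[:8]}"
    return base_job_name


print(build_dataflow_job_name("test_example", append_job_name=False))  # test-example
try:
    build_dataflow_job_name("1dfjob@", append_job_name=False)
except ValueError as err:
    print(err)  # Invalid job_name (1dfjob@)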