Code Example #1
File: beam.py Project: dskoda1/airflow
    def __init__(
        self,
        *,
        runner: str = "DirectRunner",
        default_pipeline_options: Optional[dict] = None,
        pipeline_options: Optional[dict] = None,
        gcp_conn_id: str = "google_cloud_default",
        delegate_to: Optional[str] = None,
        dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.runner = runner
        self.default_pipeline_options = default_pipeline_options or {}
        self.pipeline_options = pipeline_options or {}
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        if isinstance(dataflow_config, dict):
            self.dataflow_config = DataflowConfiguration(**dataflow_config)
        else:
            self.dataflow_config = dataflow_config or DataflowConfiguration()
        self.beam_hook: Optional[BeamHook] = None
        self.dataflow_hook: Optional[DataflowHook] = None
        self.dataflow_job_id: Optional[str] = None

        if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
            self.log.warning(
                "dataflow_config is defined but runner is different than DataflowRunner (%s)",
                self.runner,
            )
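
The constructor above (apparently the shared base of the Beam operators) accepts dataflow_config either as a DataflowConfiguration instance or as a plain dict, which it coerces via DataflowConfiguration(**dataflow_config). A minimal usage sketch of both forms; the import paths and the GCS path are assumptions based on the 2021-era provider layout, so adjust for your Airflow version:

from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration

# dataflow_config as a dict: the constructor coerces it to a
# DataflowConfiguration via DataflowConfiguration(**dataflow_config).
op_from_dict = BeamRunPythonPipelineOperator(
    task_id="beam_dict_config",
    py_file="gs://my-bucket/wordcount.py",  # hypothetical GCS path
    runner="DataflowRunner",
    dataflow_config={"job_name": "wordcount", "location": "us-central1"},
)

# Keeping the default DirectRunner while passing a dataflow_config triggers
# the "dataflow_config is defined but runner is different" warning above.
op_warns = BeamRunPythonPipelineOperator(
    task_id="beam_mismatched_runner",
    py_file="gs://my-bucket/wordcount.py",
    dataflow_config=DataflowConfiguration(job_name="wordcount"),
)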
Code Example #2
File: beam.py Project: kushsharma/airflow
    def __init__(
        self,
        *,
        py_file: str,
        runner: str = "DirectRunner",
        default_pipeline_options: Optional[dict] = None,
        pipeline_options: Optional[dict] = None,
        py_interpreter: str = "python3",
        py_options: Optional[List[str]] = None,
        py_requirements: Optional[List[str]] = None,
        py_system_site_packages: bool = False,
        gcp_conn_id: str = "google_cloud_default",
        delegate_to: Optional[str] = None,
        dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.py_file = py_file
        self.runner = runner
        self.py_options = py_options or []
        self.default_pipeline_options = default_pipeline_options or {}
        self.pipeline_options = pipeline_options or {}
        self.pipeline_options.setdefault("labels", {}).update(
            {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
        )
        self.py_interpreter = py_interpreter
        self.py_requirements = py_requirements
        self.py_system_site_packages = py_system_site_packages
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to
        self.beam_hook: Optional[BeamHook] = None
        self.dataflow_hook: Optional[DataflowHook] = None
        self.dataflow_job_id: Optional[str] = None

        if dataflow_config is None:
            self.dataflow_config = DataflowConfiguration()
        elif isinstance(dataflow_config, dict):
            self.dataflow_config = DataflowConfiguration(**dataflow_config)
        else:
            self.dataflow_config = dataflow_config

        if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
            self.log.warning(
                "dataflow_config is defined but runner is different than DataflowRunner (%s)",
                self.runner,
            )
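
For reference, the airflow-version label injected above rewrites the version string into a valid Dataflow label value, since GCP label values may not contain dots or plus signs. A quick self-contained check (the version string is hypothetical; the real module imports version from airflow.version):

version = "2.1.0.dev0+build1"  # hypothetical version string
label = "v" + version.replace(".", "-").replace("+", "-")
print(label)  # v2-1-0-dev0-build1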
Code Example #3
    def execute(self, context):
        """Execute the Apache Beam Pipeline."""
        self.beam_hook = BeamHook(runner=self.runner)
        pipeline_options = self.default_pipeline_options.copy()
        process_line_callback: Optional[Callable] = None
        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
        dataflow_job_name: Optional[str] = None

        if isinstance(self.dataflow_config, dict):
            self.dataflow_config = DataflowConfiguration(**self.dataflow_config)

        if is_dataflow:
            dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
                pipeline_options=pipeline_options, job_name_variable_key="job_name"
            )

        pipeline_options.update(self.pipeline_options)

        # Convert argument names from lowerCamelCase to snake case.
        formatted_pipeline_options = {
            convert_camel_to_snake(key): pipeline_options[key] for key in pipeline_options
        }

        with ExitStack() as exit_stack:
            if self.py_file.lower().startswith("gs://"):
                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
                tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
                self.py_file = tmp_gcs_file.name

            if is_dataflow:
                with self.dataflow_hook.provide_authorized_gcloud():
                    self.beam_hook.start_python_pipeline(
                        variables=formatted_pipeline_options,
                        py_file=self.py_file,
                        py_options=self.py_options,
                        py_interpreter=self.py_interpreter,
                        py_requirements=self.py_requirements,
                        py_system_site_packages=self.py_system_site_packages,
                        process_line_callback=process_line_callback,
                    )

                self.dataflow_hook.wait_for_done(
                    job_name=dataflow_job_name,
                    location=self.dataflow_config.location,
                    job_id=self.dataflow_job_id,
                    multiple_jobs=False,
                )

            else:
                self.beam_hook.start_python_pipeline(
                    variables=formatted_pipeline_options,
                    py_file=self.py_file,
                    py_options=self.py_options,
                    py_interpreter=self.py_interpreter,
                    py_requirements=self.py_requirements,
                    py_system_site_packages=self.py_system_site_packages,
                    process_line_callback=process_line_callback,
                )

        return {"dataflow_job_id": self.dataflow_job_id}
Code Example #4
    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock,
                                  beam_hook_mock):
        """Test DataflowHook is created and the right args are passed to
        start_java_dataflow.
        """
        dataflow_config = DataflowConfiguration()
        self.operator.runner = "DataflowRunner"
        self.operator.dataflow_config = dataflow_config
        gcs_provide_file = gcs_hook.return_value.provide_file
        dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
        self.operator.execute(None)
        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
        self.assertEqual(job_name, self.operator._dataflow_job_name)
        dataflow_hook_mock.assert_called_once_with(
            gcp_conn_id=dataflow_config.gcp_conn_id,
            delegate_to=dataflow_config.delegate_to,
            poll_sleep=dataflow_config.poll_sleep,
            impersonation_chain=dataflow_config.impersonation_chain,
            drain_pipeline=dataflow_config.drain_pipeline,
            cancel_timeout=dataflow_config.cancel_timeout,
            wait_until_finished=dataflow_config.wait_until_finished,
        )
        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)

        expected_options = {
            'project': dataflow_hook_mock.return_value.project_id,
            'jobName': job_name,
            'stagingLocation': 'gs://test/staging',
            'region': 'us-central1',
            'labels': {
                'foo': 'bar',
                'airflow-version': TEST_VERSION
            },
            'output': 'gs://test/output',
        }

        beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
            variables=expected_options,
            jar=gcs_provide_file.return_value.__enter__.return_value.name,
            job_class=JOB_CLASS,
            process_line_callback=mock.ANY,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=self.operator.dataflow_job_id,
            job_name=job_name,
            location='us-central1',
            multiple_jobs=dataflow_config.multiple_jobs,
            project_id=dataflow_hook_mock.return_value.project_id,
        )
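
The snippet omits the mock.patch decorators that supply the three mock arguments (the same presumably applies to Code Example #5). Since stacked decorators feed parameters bottom-up, they would look something like the following; the patch targets are assumptions:

from unittest import mock

# The bottom-most decorator maps to the first mock parameter (gcs_hook).
@mock.patch("airflow.providers.apache.beam.operators.beam.BeamHook")
@mock.patch("airflow.providers.apache.beam.operators.beam.DataflowHook")
@mock.patch("airflow.providers.apache.beam.operators.beam.GCSHook")
def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
    ...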
Code Example #5
    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock,
                                  beam_hook_mock):
        """Test that DataflowHook is created and the right args are passed to
        start_python_pipeline.
        """
        dataflow_config = DataflowConfiguration()
        self.operator.runner = "DataflowRunner"
        self.operator.dataflow_config = dataflow_config
        gcs_provide_file = gcs_hook.return_value.provide_file
        self.operator.execute(None)
        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
        dataflow_hook_mock.assert_called_once_with(
            gcp_conn_id=dataflow_config.gcp_conn_id,
            delegate_to=dataflow_config.delegate_to,
            poll_sleep=dataflow_config.poll_sleep,
            impersonation_chain=dataflow_config.impersonation_chain,
            drain_pipeline=dataflow_config.drain_pipeline,
            cancel_timeout=dataflow_config.cancel_timeout,
            wait_until_finished=dataflow_config.wait_until_finished,
        )
        expected_options = {
            'project': dataflow_hook_mock.return_value.project_id,
            'job_name': job_name,
            'staging_location': 'gs://test/staging',
            'output': 'gs://test/output',
            'labels': {
                'foo': 'bar',
                'airflow-version': TEST_VERSION
            },
            'region': 'us-central1',
        }
        gcs_provide_file.assert_called_once_with(object_url=PY_FILE)
        beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
            variables=expected_options,
            py_file=gcs_provide_file.return_value.__enter__.return_value.name,
            py_options=PY_OPTIONS,
            py_interpreter=PY_INTERPRETER,
            py_requirements=None,
            py_system_site_packages=False,
            process_line_callback=mock.ANY,
        )
        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
            job_id=self.operator.dataflow_job_id,
            job_name=job_name,
            location='us-central1',
            multiple_jobs=False,
        )
Code Example #6
    # [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
    start_python_pipeline_dataflow_runner = BeamRunPythonPipelineOperator(
        task_id="start_python_pipeline_dataflow_runner",
        runner="DataflowRunner",
        py_file=GCS_PYTHON,
        pipeline_options={
            'tempLocation': GCS_TMP,
            'stagingLocation': GCS_STAGING,
            'output': GCS_OUTPUT,
        },
        py_options=[],
        py_requirements=['apache-beam[gcp]==2.26.0'],
        py_interpreter='python3',
        py_system_site_packages=False,
        dataflow_config=DataflowConfiguration(
            job_name='{{task.task_id}}',
            project_id=GCP_PROJECT_ID,
            location="us-central1",
        ),
    )
    # [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file]

    start_python_pipeline_local_spark_runner = BeamRunPythonPipelineOperator(
        task_id="start_python_pipeline_local_spark_runner",
        py_file='apache_beam.examples.wordcount',
        runner="SparkRunner",
        py_options=['-m'],
        py_requirements=['apache-beam[gcp]==2.26.0'],
        py_interpreter='python3',
        py_system_site_packages=False,
    )

    # The source snippet is truncated here; the FlinkRunner body below is a
    # reconstruction mirroring the SparkRunner example above.
    start_python_pipeline_local_flink_runner = BeamRunPythonPipelineOperator(
        task_id="start_python_pipeline_local_flink_runner",
        py_file='apache_beam.examples.wordcount',
        runner="FlinkRunner",
        py_options=['-m'],
        py_requirements=['apache-beam[gcp]==2.26.0'],
        py_interpreter='python3',
        py_system_site_packages=False,
    )
Code Example #7
    def execute(self, context):
        """Execute the Apache Beam Pipeline."""
        self.beam_hook = BeamHook(runner=self.runner)
        pipeline_options = self.default_pipeline_options.copy()
        process_line_callback: Optional[Callable] = None
        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
        dataflow_job_name: Optional[str] = None

        if isinstance(self.dataflow_config, dict):
            self.dataflow_config = DataflowConfiguration(**self.dataflow_config)

        if is_dataflow:
            dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
                pipeline_options=pipeline_options, job_name_variable_key=None
            )

        pipeline_options.update(self.pipeline_options)

        with ExitStack() as exit_stack:
            if self.jar.lower().startswith("gs://"):
                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
                tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
                self.jar = tmp_gcs_file.name

            if is_dataflow:
                is_running = False
                if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
                    # project_id is deliberately not passed here: the method is
                    # wrapped by the @_fallback_to_project_id_from_variables
                    # decorator, which pulls project_id from the pipeline options
                    # and raises an error if it is supplied both ways.
                    is_running = self.dataflow_hook.is_job_dataflow_running(
                        name=self.dataflow_config.job_name,
                        variables=pipeline_options,
                    )
                    while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
                        # Same project_id fallback note as above.
                        is_running = self.dataflow_hook.is_job_dataflow_running(
                            name=self.dataflow_config.job_name,
                            variables=pipeline_options,
                        )
                if not is_running:
                    pipeline_options["jobName"] = dataflow_job_name
                    with self.dataflow_hook.provide_authorized_gcloud():
                        self.beam_hook.start_java_pipeline(
                            variables=pipeline_options,
                            jar=self.jar,
                            job_class=self.job_class,
                            process_line_callback=process_line_callback,
                        )
                    self.dataflow_hook.wait_for_done(
                        job_name=dataflow_job_name,
                        location=self.dataflow_config.location,
                        job_id=self.dataflow_job_id,
                        multiple_jobs=self.dataflow_config.multiple_jobs,
                        project_id=self.dataflow_config.project_id,
                    )

            else:
                self.beam_hook.start_java_pipeline(
                    variables=pipeline_options,
                    jar=self.jar,
                    job_class=self.job_class,
                    process_line_callback=process_line_callback,
                )

        return {"dataflow_job_id": self.dataflow_job_id}
Code Example #8
    def execute(self, context):
        """Execute the Apache Beam Pipeline."""
        self.beam_hook = BeamHook(runner=self.runner)
        pipeline_options = self.default_pipeline_options.copy()
        process_line_callback: Optional[Callable] = None
        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

        if isinstance(self.dataflow_config, dict):
            self.dataflow_config = DataflowConfiguration(**self.dataflow_config)

        if is_dataflow:
            self.dataflow_hook = DataflowHook(
                gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
                delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
                poll_sleep=self.dataflow_config.poll_sleep,
                impersonation_chain=self.dataflow_config.impersonation_chain,
                drain_pipeline=self.dataflow_config.drain_pipeline,
                cancel_timeout=self.dataflow_config.cancel_timeout,
                wait_until_finished=self.dataflow_config.wait_until_finished,
            )
            self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id

            self._dataflow_job_name = DataflowHook.build_dataflow_job_name(
                self.dataflow_config.job_name, self.dataflow_config.append_job_name
            )
            pipeline_options["jobName"] = self.dataflow_config.job_name
            pipeline_options["project"] = self.dataflow_config.project_id
            pipeline_options["region"] = self.dataflow_config.location
            pipeline_options.setdefault("labels", {}).update(
                {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
            )

            def set_current_dataflow_job_id(job_id):
                self.dataflow_job_id = job_id

            process_line_callback = process_line_and_extract_dataflow_job_id_callback(
                on_new_job_id_callback=set_current_dataflow_job_id
            )

        pipeline_options.update(self.pipeline_options)

        with ExitStack() as exit_stack:
            if self.jar.lower().startswith("gs://"):
                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
                tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
                    gcs_hook.provide_file(object_url=self.jar))
                self.jar = tmp_gcs_file.name

            if is_dataflow:
                is_running = False
                if self.dataflow_config.check_if_running != CheckJobRunning.IgnoreJob:
                    # The pylint disable below is needed because project_id is
                    # required but deliberately not passed: the method is wrapped
                    # by the @_fallback_to_project_id_from_variables decorator,
                    # which pulls project_id from the pipeline options and raises
                    # an error if it is supplied both ways.
                    is_running = self.dataflow_hook.is_job_dataflow_running(  # pylint: disable=no-value-for-parameter
                        name=self.dataflow_config.job_name,
                        variables=pipeline_options,
                    )
                    while is_running and self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun:
                        # Same project_id fallback note as above.
                        # pylint: disable=no-value-for-parameter
                        is_running = self.dataflow_hook.is_job_dataflow_running(
                            name=self.dataflow_config.job_name,
                            variables=pipeline_options,
                        )
                if not is_running:
                    pipeline_options["jobName"] = self._dataflow_job_name
                    self.beam_hook.start_java_pipeline(
                        variables=pipeline_options,
                        jar=self.jar,
                        job_class=self.job_class,
                        process_line_callback=process_line_callback,
                    )
                    self.dataflow_hook.wait_for_done(
                        job_name=self._dataflow_job_name,
                        location=self.dataflow_config.location,
                        job_id=self.dataflow_job_id,
                        multiple_jobs=self.dataflow_config.multiple_jobs,
                        project_id=self.dataflow_config.project_id,
                    )

            else:
                self.beam_hook.start_java_pipeline(
                    variables=pipeline_options,
                    jar=self.jar,
                    job_class=self.job_class,
                    process_line_callback=process_line_callback,
                )

        return {"dataflow_job_id": self.dataflow_job_id}
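
process_line_and_extract_dataflow_job_id_callback watches the pipeline subprocess output for the Dataflow job id and hands it to the supplied callback, which is how self.dataflow_job_id gets populated. A self-contained sketch of that pattern; the log line format in the regex is an assumption, not the provider's actual matching logic:

import re
from typing import Callable, Optional

def make_job_id_callback(on_new_job_id: Callable[[str], None]) -> Callable[[str], None]:
    # Assumed log shape: Java runners print a line like "Submitted job: <id>".
    pattern = re.compile(r"Submitted job: (?P<job>\S+)")

    def process_line(line: str) -> None:
        match = pattern.search(line)
        if match:
            on_new_job_id(match.group("job"))

    return process_line

job_id: Optional[str] = None

def remember(new_id: str) -> None:
    global job_id
    job_id = new_id

callback = make_job_id_callback(remember)
callback("INFO: Submitted job: 2021-01-01_00_00_00-1234567890")
print(job_id)  # 2021-01-01_00_00_00-1234567890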
Code Example #9
    def execute(self, context):
        """Execute the Apache Beam Pipeline."""
        self.beam_hook = BeamHook(runner=self.runner)
        pipeline_options = self.default_pipeline_options.copy()
        process_line_callback: Optional[Callable] = None
        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

        if isinstance(self.dataflow_config, dict):
            self.dataflow_config = DataflowConfiguration(**self.dataflow_config)

        if is_dataflow:
            self.dataflow_hook = DataflowHook(
                gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
                delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
                poll_sleep=self.dataflow_config.poll_sleep,
                impersonation_chain=self.dataflow_config.impersonation_chain,
                drain_pipeline=self.dataflow_config.drain_pipeline,
                cancel_timeout=self.dataflow_config.cancel_timeout,
                wait_until_finished=self.dataflow_config.wait_until_finished,
            )
            self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id

            dataflow_job_name = DataflowHook.build_dataflow_job_name(
                self.dataflow_config.job_name, self.dataflow_config.append_job_name
            )
            pipeline_options["job_name"] = dataflow_job_name
            pipeline_options["project"] = self.dataflow_config.project_id
            pipeline_options["region"] = self.dataflow_config.location
            pipeline_options.setdefault("labels", {}).update(
                {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
            )

            def set_current_dataflow_job_id(job_id):
                self.dataflow_job_id = job_id

            process_line_callback = process_line_and_extract_dataflow_job_id_callback(
                on_new_job_id_callback=set_current_dataflow_job_id
            )

        pipeline_options.update(self.pipeline_options)

        # Convert argument names from lowerCamelCase to snake case.
        formatted_pipeline_options = {
            convert_camel_to_snake(key): pipeline_options[key] for key in pipeline_options
        }

        with ExitStack() as exit_stack:
            if self.py_file.lower().startswith("gs://"):
                gcs_hook = GCSHook(self.gcp_conn_id, self.delegate_to)
                tmp_gcs_file = exit_stack.enter_context(  # pylint: disable=no-member
                    gcs_hook.provide_file(object_url=self.py_file))
                self.py_file = tmp_gcs_file.name

            self.beam_hook.start_python_pipeline(
                variables=formatted_pipeline_options,
                py_file=self.py_file,
                py_options=self.py_options,
                py_interpreter=self.py_interpreter,
                py_requirements=self.py_requirements,
                py_system_site_packages=self.py_system_site_packages,
                process_line_callback=process_line_callback,
            )

            if is_dataflow:
                self.dataflow_hook.wait_for_done(  # pylint: disable=no-value-for-parameter
                    job_name=dataflow_job_name,
                    location=self.dataflow_config.location,
                    job_id=self.dataflow_job_id,
                    multiple_jobs=False,
                )

        return {"dataflow_job_id": self.dataflow_job_id}
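
Since execute() returns {"dataflow_job_id": ...}, the id is pushed to XCom as the task's return_value. Assuming Airflow 2's context injection, a downstream task could read it like this (task ids are hypothetical):

from airflow.operators.python import PythonOperator

def report_job_id(ti):
    # Pull the dict returned by the Beam operator's execute().
    result = ti.xcom_pull(task_ids="start_python_pipeline_dataflow_runner")
    print("Dataflow job id:", result["dataflow_job_id"])

report = PythonOperator(
    task_id="report_job_id",
    python_callable=report_job_id,
)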