Example #1
0
 def test_beam_options_to_args(self, options, expected_args):
     """Converting ``options`` with ``beam_options_to_args`` yields ``expected_args``."""
     assert beam_options_to_args(options) == expected_args
Example #2
0
    def start_sql_job(
        self,
        job_name: str,
        query: str,
        options: Dict[str, Any],
        project_id: str,
        location: str = DEFAULT_DATAFLOW_LOCATION,
        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
    ):
        """
        Starts Dataflow SQL query.

        Runs ``gcloud dataflow sql query`` as a subprocess, waits for the
        resulting Dataflow job to finish and returns its refreshed state.

        :param job_name: The unique name to assign to the Cloud Dataflow job.
        :type job_name: str
        :param query: The SQL query to execute.
        :type query: str
        :param options: Job parameters to be executed. They are turned into
            extra command-line flags with ``beam_options_to_args``.
            For more information, look at:
            `gcloud beta dataflow sql query
            <https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query>`__
            command reference
        :param location: The location of the Dataflow job (for example europe-west1)
        :type location: str
        :param project_id: The ID of the GCP project that owns the job.
            If set to ``None`` or missing, the default project_id from the GCP connection is used.
        :type project_id: Optional[str]
        :param on_new_job_id_callback: Callback called when the job ID is known.
        :type on_new_job_id_callback: callable
        :return: the new job object
        :raises AirflowException: if the ``gcloud`` command exits with a non-zero
            code, or exits successfully without printing a job ID.
        """
        cmd = [
            "gcloud",
            "dataflow",
            "sql",
            "query",
            query,
            f"--project={project_id}",
            # Makes gcloud print only the ID of the created job on stdout.
            "--format=value(job.id)",
            f"--job-name={job_name}",
            f"--region={location}",
            *(beam_options_to_args(options)),
        ]
        self.log.info("Executing command: %s",
                      " ".join(shlex.quote(c) for c in cmd))
        with self.provide_authorized_gcloud():
            # Argument list, no shell: the query text is never shell-interpreted.
            proc = subprocess.run(cmd, capture_output=True)
        stdout = proc.stdout.decode()
        self.log.info("Output: %s", stdout)
        self.log.warning("Stderr: %s", proc.stderr.decode())
        self.log.info("Exit code %d", proc.returncode)
        if proc.returncode != 0:
            raise AirflowException(
                f"Process exit with non-zero exit code. Exit code: {proc.returncode}"
            )
        job_id = stdout.strip()
        if not job_id:
            # Fail fast: otherwise the callback and the jobs controller would
            # be handed an empty job ID.
            raise AirflowException(
                "Process exited successfully but did not print a Dataflow job ID."
            )

        self.log.info("Created job ID: %s", job_id)
        if on_new_job_id_callback:
            on_new_job_id_callback(job_id)

        jobs_controller = _DataflowJobsController(
            dataflow=self.get_conn(),
            project_number=project_id,
            job_id=job_id,
            location=location,
            poll_sleep=self.poll_sleep,
            num_retries=self.num_retries,
            drain_pipeline=self.drain_pipeline,
            wait_until_finished=self.wait_until_finished,
        )
        jobs_controller.wait_for_done()

        return jobs_controller.get_jobs(refresh=True)[0]
Example #3
0
    def start_sql_job(
        self,
        job_name: str,
        query: str,
        options: Dict[str, Any],
        project_id: str,
        location: str = DEFAULT_DATAFLOW_LOCATION,
        on_new_job_id_callback: Optional[Callable[[str], None]] = None,
        on_new_job_callback: Optional[Callable[[dict], None]] = None,
    ):
        """
        Starts Dataflow SQL query.

        Runs ``gcloud dataflow sql query`` as a subprocess (optionally
        impersonating a single service account), notifies the callbacks once
        the job is known, waits for it to finish and returns its refreshed state.

        :param job_name: The unique name to assign to the Cloud Dataflow job.
        :param query: The SQL query to execute.
        :param options: Job parameters to be executed. They are turned into
            extra command-line flags with ``beam_options_to_args``.
            For more information, look at:
            `gcloud beta dataflow sql query
            <https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query>`__
            command reference
        :param location: The location of the Dataflow job (for example europe-west1)
        :param project_id: The ID of the GCP project that owns the job.
            If set to ``None`` or missing, the default project_id from the GCP connection is used.
        :param on_new_job_id_callback: (Deprecated) Callback called when the job ID is known.
        :param on_new_job_callback: Callback called when the job is known.
        :return: the new job object
        :raises AirflowException: if ``impersonation_chain`` lists more than one
            account, if the ``gcloud`` command exits with a non-zero code, or if
            it exits successfully without printing a job ID.
        """
        gcp_options = [
            f"--project={project_id}",
            # Makes gcloud print only the ID of the created job on stdout.
            "--format=value(job.id)",
            f"--job-name={job_name}",
            f"--region={location}",
        ]

        if self.impersonation_chain:
            # gcloud accepts a single account here, so only a plain string or a
            # one-element sequence can be forwarded.
            if isinstance(self.impersonation_chain, str):
                impersonation_account = self.impersonation_chain
            elif len(self.impersonation_chain) == 1:
                impersonation_account = self.impersonation_chain[0]
            else:
                raise AirflowException(
                    "Chained list of accounts is not supported, please specify only one service account"
                )
            gcp_options.append(
                f"--impersonate-service-account={impersonation_account}")

        cmd = [
            "gcloud",
            "dataflow",
            "sql",
            "query",
            query,
            *gcp_options,
            *(beam_options_to_args(options)),
        ]
        self.log.info("Executing command: %s",
                      " ".join(shlex.quote(c) for c in cmd))
        with self.provide_authorized_gcloud():
            # Argument list, no shell: the query text is never shell-interpreted.
            proc = subprocess.run(cmd, capture_output=True)
        stdout = proc.stdout.decode()
        self.log.info("Output: %s", stdout)
        self.log.warning("Stderr: %s", proc.stderr.decode())
        self.log.info("Exit code %d", proc.returncode)
        if proc.returncode != 0:
            raise AirflowException(
                f"Process exit with non-zero exit code. Exit code: {proc.returncode}"
            )
        job_id = stdout.strip()
        if not job_id:
            # Fail fast: otherwise the jobs controller would be built with an
            # empty job ID and the callbacks would run on an unidentified job.
            raise AirflowException(
                "Process exited successfully but did not print a Dataflow job ID."
            )

        self.log.info("Created job ID: %s", job_id)

        jobs_controller = _DataflowJobsController(
            dataflow=self.get_conn(),
            project_number=project_id,
            job_id=job_id,
            location=location,
            poll_sleep=self.poll_sleep,
            num_retries=self.num_retries,
            drain_pipeline=self.drain_pipeline,
            wait_until_finished=self.wait_until_finished,
        )
        # Fetch the job once before waiting so callbacks see it while it runs.
        job = jobs_controller.get_jobs(refresh=True)[0]

        if on_new_job_id_callback:
            warnings.warn(
                "on_new_job_id_callback is Deprecated. Please start using on_new_job_callback",
                DeprecationWarning,
                stacklevel=3,
            )
            on_new_job_id_callback(cast(str, job.get("id")))

        if on_new_job_callback:
            on_new_job_callback(job)

        jobs_controller.wait_for_done()
        return jobs_controller.get_jobs(refresh=True)[0]