示例#1
0
    def poke(self, context: "Context") -> bool:
        """Check once whether the watched Dataproc job has finished.

        When ``self.wait_timeout`` is set, a ``ServerError`` raised while
        fetching the job is tolerated: the poke reports "not done yet" until
        ``self._duration()`` exceeds ``self.wait_timeout``, at which point the
        sensor fails. Without a timeout, any error from the hook propagates.

        :param context: task context handed to the sensor; not used here.
        :return: ``True`` if the job state is ``DONE``, ``False`` if the job
            should be polled again.
        :raises AirflowException: if the job is in ``ERROR`` or one of the
            cancel states, or if server errors persist past ``wait_timeout``.
        """
        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
        if self.wait_timeout:
            try:
                job = hook.get_job(job_id=self.dataproc_job_id,
                                   region=self.region,
                                   project_id=self.project_id)
            except ServerError as err:
                # Read the elapsed time once so the value that is logged is
                # the same value that is compared against the timeout.
                duration = self._duration()
                self.log.info("DURATION RUN: %s", duration)
                if duration > self.wait_timeout:
                    # Chain the server error so the root cause stays visible
                    # in the traceback of the timeout failure.
                    raise AirflowException(
                        f"Timeout: dataproc job {self.dataproc_job_id} "
                        f"is not ready after {self.wait_timeout}s") from err
                self.log.info(
                    "Retrying. Dataproc API returned server error when waiting for job: %s",
                    err)
                return False
        else:
            job = hook.get_job(job_id=self.dataproc_job_id,
                               region=self.region,
                               project_id=self.project_id)

        state = job.status.state
        if state == JobStatus.State.ERROR:
            raise AirflowException(f'Job failed:\n{job}')
        elif state in {
                JobStatus.State.CANCELLED,
                JobStatus.State.CANCEL_PENDING,
                JobStatus.State.CANCEL_STARTED,
        }:
            raise AirflowException(f'Job was cancelled:\n{job}')
        elif JobStatus.State.DONE == state:
            self.log.debug("Job %s completed successfully.",
                           self.dataproc_job_id)
            return True
        elif JobStatus.State.ATTEMPT_FAILURE == state:
            # A failed attempt is not final: log it and keep waiting.
            self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

        self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
        return False
示例#2
0
文件: dataproc.py 项目: lgov/airflow
    def poke(self, context: dict) -> bool:
        """Poll the Dataproc job once and report whether it is done.

        :param context: task context handed to the sensor; not used here.
        :return: ``True`` when the job state is ``DONE``; ``False`` when the
            job has to be polled again.
        :raises AirflowException: when the job state is ``ERROR`` or one of
            the cancel states.
        """
        dataproc = DataprocHook(gcp_conn_id=self.gcp_conn_id)
        job = dataproc.get_job(
            job_id=self.dataproc_job_id,
            location=self.location,
            project_id=self.project_id,
        )
        current = job.status.state

        # Terminal failure states abort the sensor immediately.
        if current == JobStatus.ERROR:
            raise AirflowException('Job failed:\n{}'.format(job))
        if current in {
            JobStatus.CANCELLED,
            JobStatus.CANCEL_PENDING,
            JobStatus.CANCEL_STARTED,
        }:
            raise AirflowException('Job was cancelled:\n{}'.format(job))

        if JobStatus.DONE == current:
            self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
            return True

        # A failed attempt is only noted; the sensor keeps waiting.
        if JobStatus.ATTEMPT_FAILURE == current:
            self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

        self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
        return False
示例#3
0
class TestDataprocHook(unittest.TestCase):
    """Unit tests for ``DataprocHook``.

    Each test patches the client factory (or a sibling hook method) and
    checks that the hook forwards its arguments to the underlying call with
    the expected keyword names and defaults.
    """

    def setUp(self):
        """Create a hook with ``GoogleBaseHook.__init__`` replaced by ``mock_init``."""
        with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"),
                        new=mock_init):
            self.hook = DataprocHook(gcp_conn_id="test")

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
    def test_get_cluster_client(self, mock_client, mock_client_info,
                                mock_get_credentials):
        """The cluster client is built with a location-specific API endpoint."""
        self.hook.get_cluster_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_global(self, mock_client, mock_client_info,
                                        mock_get_credentials):
        """Without a location, the template client gets no client options."""
        _ = self.hook.get_template_client()
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_region(self, mock_client, mock_client_info,
                                        mock_get_credentials):
        """With a location, the template client targets that location's endpoint."""
        _ = self.hook.get_template_client(location='region1')
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                'api_endpoint': 'region1-dataproc.googleapis.com:443'
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("JobControllerClient"))
    def test_get_job_client(self, mock_client, mock_client_info,
                            mock_get_credentials):
        """The job client is built with a location-specific API endpoint."""
        self.hook.get_job_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_create_cluster(self, mock_client):
        """``create_cluster`` forwards a ``CLUSTER`` payload to the cluster client."""
        self.hook.create_cluster(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_config=CLUSTER_CONFIG,
            labels=LABELS,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.create_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_delete_cluster(self, mock_client):
        """``delete_cluster`` forwards its arguments with ``None`` defaults."""
        self.hook.delete_cluster(project_id=GCP_PROJECT,
                                 region=GCP_LOCATION,
                                 cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.delete_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_uuid=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_diagnose_cluster(self, mock_client):
        """``diagnose_cluster`` calls the client and then ``result()`` on what it returns."""
        self.hook.diagnose_cluster(project_id=GCP_PROJECT,
                                   region=GCP_LOCATION,
                                   cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.diagnose_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )
        mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with(
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_get_cluster(self, mock_client):
        """``get_cluster`` forwards its arguments with ``None`` defaults."""
        self.hook.get_cluster(project_id=GCP_PROJECT,
                              region=GCP_LOCATION,
                              cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_list_clusters(self, mock_client):
        """``list_clusters`` forwards the filter with ``None`` defaults."""
        filter_ = "filter"

        self.hook.list_clusters(project_id=GCP_PROJECT,
                                region=GCP_LOCATION,
                                filter_=filter_)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.list_clusters.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            filter_=filter_,
            page_size=None,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_update_cluster(self, mock_client):
        """``update_cluster`` passes the hook's ``location`` to the client as ``region``."""
        update_mask = "update-mask"
        self.hook.update_cluster(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.update_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
            graceful_decommission_timeout=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_create_workflow_template(self, mock_client):
        """``create_workflow_template`` uses ``region_path`` as the parent."""
        template = {"test": "test"}
        mock_client.return_value.region_path.return_value = PARENT
        self.hook.create_workflow_template(location=GCP_LOCATION,
                                           template=template,
                                           project_id=GCP_PROJECT)
        mock_client.return_value.region_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION)
        mock_client.return_value.create_workflow_template.assert_called_once_with(
            parent=PARENT,
            template=template,
            retry=None,
            timeout=None,
            metadata=None)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_workflow_template(self, mock_client):
        """``instantiate_workflow_template`` uses ``workflow_template_path`` as the name."""
        template_name = "template_name"
        mock_client.return_value.workflow_template_path.return_value = NAME
        self.hook.instantiate_workflow_template(location=GCP_LOCATION,
                                                template_name=template_name,
                                                project_id=GCP_PROJECT)
        mock_client.return_value.workflow_template_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION, template_name)
        mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
            name=NAME,
            version=None,
            parameters=None,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_inline_workflow_template(self, mock_client):
        """``instantiate_inline_workflow_template`` uses ``region_path`` as the parent."""
        template = {"test": "test"}
        mock_client.return_value.region_path.return_value = PARENT
        self.hook.instantiate_inline_workflow_template(location=GCP_LOCATION,
                                                       template=template,
                                                       project_id=GCP_PROJECT)
        mock_client.return_value.region_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION)
        mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
            parent=PARENT,
            template=template,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
    def test_wait_for_job(self, mock_get_job):
        """``wait_for_job`` raises once ``get_job`` reports ``ERROR`` after ``RUNNING``."""
        mock_get_job.side_effect = [
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
        ]
        with pytest.raises(AirflowException):
            self.hook.wait_for_job(job_id=JOB_ID,
                                   location=GCP_LOCATION,
                                   project_id=GCP_PROJECT,
                                   wait_time=0)
        calls = [
            mock.call(location=GCP_LOCATION,
                      job_id=JOB_ID,
                      project_id=GCP_PROJECT),
            mock.call(location=GCP_LOCATION,
                      job_id=JOB_ID,
                      project_id=GCP_PROJECT),
        ]
        # ``has_calls`` is not a Mock assertion: it would either verify
        # nothing or raise AttributeError on newer Python versions.
        # ``assert_has_calls`` actually checks the recorded calls.
        mock_get_job.assert_has_calls(calls)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_get_job(self, mock_client):
        """``get_job`` passes the hook's ``location`` to the client as ``region``."""
        self.hook.get_job(location=GCP_LOCATION,
                          job_id=JOB_ID,
                          project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_submit_job(self, mock_client):
        """``submit_job`` passes the hook's ``location`` to the client as ``region``."""
        self.hook.submit_job(location=GCP_LOCATION,
                             job=JOB,
                             project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.submit_job.assert_called_once_with(
            region=GCP_LOCATION,
            job=JOB,
            project_id=GCP_PROJECT,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.wait_for_job"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
    def test_submit(self, mock_submit_job, mock_wait_for_job):
        """``submit`` warns about deprecation, then submits and waits for the job."""
        mock_submit_job.return_value.reference.job_id = JOB_ID
        with pytest.warns(DeprecationWarning):
            self.hook.submit(project_id=GCP_PROJECT,
                             job=JOB,
                             region=GCP_LOCATION)
        mock_submit_job.assert_called_once_with(location=GCP_LOCATION,
                                                project_id=GCP_PROJECT,
                                                job=JOB)
        mock_wait_for_job.assert_called_once_with(location=GCP_LOCATION,
                                                  project_id=GCP_PROJECT,
                                                  job_id=JOB_ID)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job(self, mock_client):
        """``cancel_job`` passes the hook's ``location`` to the client as ``region``."""
        self.hook.cancel_job(location=GCP_LOCATION,
                             job_id=JOB_ID,
                             project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.cancel_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job_deprecation_warning(self, mock_client):
        """Without a location, ``cancel_job`` warns and falls back to ``'global'``."""
        with pytest.warns(DeprecationWarning):
            self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location='global')
        mock_client.return_value.cancel_job.assert_called_once_with(
            region='global',
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )