def poke(self, context: "Context") -> bool:
    """Poll the Dataproc job once and report whether it has completed.

    The job is fetched through a ``DataprocHook`` built from ``self.gcp_conn_id``.
    When ``self.wait_timeout`` is truthy, a ``ServerError`` raised by the lookup
    is tolerated (the poke returns ``False``) until ``self._duration()`` exceeds
    ``self.wait_timeout``. When it is falsy, the lookup is performed without
    any ``ServerError`` handling.

    :param context: task context supplied by the caller; not read by this method.
    :return: ``True`` when the job state is ``DONE``; ``False`` when the job is
        in any other non-failing state or a tolerated server error occurred.
    :raises AirflowException: if the job state is ``ERROR``, if the job is
        cancelled / being cancelled, or if server errors persist past
        ``self.wait_timeout``.
    """
    dataproc_hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)

    if not self.wait_timeout:
        # No timeout configured: any exception from the lookup propagates.
        job = dataproc_hook.get_job(
            job_id=self.dataproc_job_id,
            region=self.region,
            project_id=self.project_id,
        )
    else:
        try:
            job = dataproc_hook.get_job(
                job_id=self.dataproc_job_id,
                region=self.region,
                project_id=self.project_id,
            )
        except ServerError as err:
            self.log.info(f"DURATION RUN: {self._duration()}")
            # Give up only once the elapsed duration exceeds the allowed wait.
            if self._duration() > self.wait_timeout:
                raise AirflowException(
                    f"Timeout: dataproc job {self.dataproc_job_id} "
                    f"is not ready after {self.wait_timeout}s"
                )
            self.log.info(
                "Retrying. Dataproc API returned server error when waiting for job: %s",
                err,
            )
            return False

    job_state = job.status.state

    if job_state == JobStatus.State.ERROR:
        raise AirflowException(f'Job failed:\n{job}')

    cancelled_states = {
        JobStatus.State.CANCELLED,
        JobStatus.State.CANCEL_PENDING,
        JobStatus.State.CANCEL_STARTED,
    }
    if job_state in cancelled_states:
        raise AirflowException(f'Job was cancelled:\n{job}')

    if JobStatus.State.DONE == job_state:
        self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
        return True

    # A failed attempt is only logged; the sensor keeps waiting.
    if JobStatus.State.ATTEMPT_FAILURE == job_state:
        self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

    self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
    return False
def poke(self, context: dict) -> bool:
    """Poll the Dataproc job once and report whether it has completed.

    The job is fetched through a ``DataprocHook`` built from ``self.gcp_conn_id``
    using ``self.dataproc_job_id``, ``self.location`` and ``self.project_id``.

    :param context: task context supplied by the caller; not read by this method.
    :return: ``True`` when the job state is ``DONE``; ``False`` for any other
        non-failing state.
    :raises AirflowException: if the job state is ``ERROR`` or the job is
        cancelled / being cancelled.
    """
    dataproc_hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
    job = dataproc_hook.get_job(
        job_id=self.dataproc_job_id,
        location=self.location,
        project_id=self.project_id,
    )
    job_state = job.status.state

    if job_state == JobStatus.ERROR:
        raise AirflowException('Job failed:\n{}'.format(job))

    cancelled_states = {
        JobStatus.CANCELLED,
        JobStatus.CANCEL_PENDING,
        JobStatus.CANCEL_STARTED,
    }
    if job_state in cancelled_states:
        raise AirflowException('Job was cancelled:\n{}'.format(job))

    if JobStatus.DONE == job_state:
        self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
        return True

    # A failed attempt is only logged; the sensor keeps waiting.
    if JobStatus.ATTEMPT_FAILURE == job_state:
        self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

    self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
    return False
class TestDataprocHook(unittest.TestCase):
    """Unit tests for ``DataprocHook``.

    Every test patches the underlying client factory (or client class) and
    verifies that the hook forwards its arguments to the client unchanged,
    filling unspecified optional arguments with ``None``.
    """

    def setUp(self):
        # Replace GoogleBaseHook.__init__ with the module's mock_init while the hook is constructed.
        with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
            self.hook = DataprocHook(gcp_conn_id="test")

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
    def test_get_cluster_client(self, mock_client, mock_client_info, mock_get_credentials):
        """The cluster client is created with a region-specific API endpoint."""
        self.hook.get_cluster_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={"api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"},
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_global(self, mock_client, mock_client_info, mock_get_credentials):
        """Without a location, the template client is created with no client options."""
        _ = self.hook.get_template_client()
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_region(self, mock_client, mock_client_info, mock_get_credentials):
        """With a location, the template client is created with a region-specific API endpoint."""
        _ = self.hook.get_template_client(location='region1')
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={'api_endpoint': 'region1-dataproc.googleapis.com:443'},
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("JobControllerClient"))
    def test_get_job_client(self, mock_client, mock_client_info, mock_get_credentials):
        """The job client is created with a region-specific API endpoint."""
        self.hook.get_job_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={"api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"},
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_create_cluster(self, mock_client):
        """create_cluster forwards the expected cluster definition to the cluster client."""
        self.hook.create_cluster(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_config=CLUSTER_CONFIG,
            labels=LABELS,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.create_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_delete_cluster(self, mock_client):
        """delete_cluster forwards its arguments to the cluster client."""
        self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.delete_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_uuid=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_diagnose_cluster(self, mock_client):
        """diagnose_cluster calls the client and then ``result()`` on what the client returns."""
        self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.diagnose_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )
        mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with()

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_get_cluster(self, mock_client):
        """get_cluster forwards its arguments to the cluster client."""
        self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_list_clusters(self, mock_client):
        """list_clusters forwards the filter to the cluster client."""
        filter_ = "filter"

        self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.list_clusters.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            filter_=filter_,
            page_size=None,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_update_cluster(self, mock_client):
        """update_cluster passes its ``location`` argument to the client as ``region``."""
        update_mask = "update-mask"
        self.hook.update_cluster(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.update_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
            graceful_decommission_timeout=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_create_workflow_template(self, mock_client):
        """create_workflow_template uses the client's ``region_path`` result as the parent."""
        template = {"test": "test"}
        mock_client.return_value.region_path.return_value = PARENT
        self.hook.create_workflow_template(location=GCP_LOCATION, template=template, project_id=GCP_PROJECT)
        mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
        mock_client.return_value.create_workflow_template.assert_called_once_with(
            parent=PARENT, template=template, retry=None, timeout=None, metadata=None
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_workflow_template(self, mock_client):
        """instantiate_workflow_template uses the client's ``workflow_template_path`` result as the name."""
        template_name = "template_name"
        mock_client.return_value.workflow_template_path.return_value = NAME
        self.hook.instantiate_workflow_template(
            location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
        )
        mock_client.return_value.workflow_template_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION, template_name
        )
        mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
            name=NAME,
            version=None,
            parameters=None,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_inline_workflow_template(self, mock_client):
        """instantiate_inline_workflow_template uses the client's ``region_path`` result as the parent."""
        template = {"test": "test"}
        mock_client.return_value.region_path.return_value = PARENT
        self.hook.instantiate_inline_workflow_template(
            location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
        )
        mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
        mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
            parent=PARENT,
            template=template,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
    def test_wait_for_job(self, mock_get_job):
        """wait_for_job raises AirflowException when a polled job reports ERROR."""
        # First poll reports RUNNING, second reports ERROR.
        mock_get_job.side_effect = [
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
        ]
        with pytest.raises(AirflowException):
            self.hook.wait_for_job(job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0)
        calls = [
            mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
            mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
        ]
        # ``has_calls`` is not a Mock assertion: it merely auto-creates a child mock
        # and verifies nothing. ``assert_has_calls`` performs the intended check.
        mock_get_job.assert_has_calls(calls)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_get_job(self, mock_client):
        """get_job passes its ``location`` argument to the job client as ``region``."""
        self.hook.get_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_submit_job(self, mock_client):
        """submit_job passes its ``location`` argument to the job client as ``region``."""
        self.hook.submit_job(location=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.submit_job.assert_called_once_with(
            region=GCP_LOCATION,
            job=JOB,
            project_id=GCP_PROJECT,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.wait_for_job"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
    def test_submit(self, mock_submit_job, mock_wait_for_job):
        """submit warns about deprecation, then calls submit_job and wait_for_job."""
        mock_submit_job.return_value.reference.job_id = JOB_ID
        with pytest.warns(DeprecationWarning):
            self.hook.submit(project_id=GCP_PROJECT, job=JOB, region=GCP_LOCATION)
        mock_submit_job.assert_called_once_with(location=GCP_LOCATION, project_id=GCP_PROJECT, job=JOB)
        mock_wait_for_job.assert_called_once_with(
            location=GCP_LOCATION, project_id=GCP_PROJECT, job_id=JOB_ID
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job(self, mock_client):
        """cancel_job passes its ``location`` argument to the job client as ``region``."""
        self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.cancel_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job_deprecation_warning(self, mock_client):
        """Omitting ``location`` emits a DeprecationWarning and uses the 'global' region."""
        with pytest.warns(DeprecationWarning):
            self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location='global')
        mock_client.return_value.cancel_job.assert_called_once_with(
            region='global',
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )