class TestDataprocHook(unittest.TestCase):
    """Unit tests for ``DataprocHook`` with every Dataproc client mocked out.

    Each test calls one hook method and checks the arguments that the hook
    forwards to the (mocked) client factory or client method.
    """

    def setUp(self):
        """Create the hook with ``CloudBaseHook.__init__`` replaced by ``mock_init``."""
        with mock.patch(BASE_STRING.format("CloudBaseHook.__init__"), new=mock_init):
            self.hook = DataprocHook(gcp_conn_id="test")

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(
        DATAPROC_STRING.format("DataprocHook.client_info"),
        new_callable=mock.PropertyMock,
    )
    @mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
    def test_get_cluster_client(self, mock_client, mock_client_info, mock_get_credentials):
        """The cluster client is built with a location-specific API endpoint."""
        self.hook.get_cluster_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(
        DATAPROC_STRING.format("DataprocHook.client_info"),
        new_callable=mock.PropertyMock,
    )
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client(self, mock_client, mock_client_info, mock_get_credentials):
        """Accessing ``get_template_client`` builds the workflow template client."""
        # Attribute access only (no call): the test relies on the access itself
        # constructing the client.
        _ = self.hook.get_template_client
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(
        DATAPROC_STRING.format("DataprocHook.client_info"),
        new_callable=mock.PropertyMock,
    )
    @mock.patch(DATAPROC_STRING.format("JobControllerClient"))
    def test_get_job_client(self, mock_client, mock_client_info, mock_get_credentials):
        """The job client is built with a location-specific API endpoint."""
        self.hook.get_job_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_create_cluster(self, mock_client):
        """``create_cluster`` forwards its arguments to the cluster client."""
        self.hook.create_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster=CLUSTER)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.create_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_delete_cluster(self, mock_client):
        """``delete_cluster`` forwards its arguments to the cluster client."""
        self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.delete_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_uuid=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_diagnose_cluster(self, mock_client):
        """``diagnose_cluster`` forwards its arguments to the cluster client."""
        self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.diagnose_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_get_cluster(self, mock_client):
        """``get_cluster`` forwards its arguments to the cluster client."""
        self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_list_clusters(self, mock_client):
        """``list_clusters`` forwards the filter to the cluster client."""
        filter_ = "filter"

        self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.list_clusters.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            filter_=filter_,
            page_size=None,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_update_cluster(self, mock_client):
        """``update_cluster`` passes ``location`` on to the client as ``region``."""
        update_mask = "update-mask"
        self.hook.update_cluster(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.update_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
            graceful_decommission_timeout=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_create_workflow_template(self, mock_client):
        """The template is created under the parent returned by ``region_path``."""
        template = {"test": "test"}
        mock_client.region_path.return_value = PARENT
        self.hook.create_workflow_template(
            location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
        )
        mock_client.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
        mock_client.create_workflow_template.assert_called_once_with(
            parent=PARENT, template=template, retry=None, timeout=None, metadata=None
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_workflow_template(self, mock_client):
        """The template is instantiated by the name from ``workflow_template_path``."""
        template_name = "template_name"
        mock_client.workflow_template_path.return_value = NAME
        self.hook.instantiate_workflow_template(
            location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
        )
        mock_client.workflow_template_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION, template_name
        )
        mock_client.instantiate_workflow_template.assert_called_once_with(
            name=NAME,
            version=None,
            parameters=None,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_inline_workflow_template(self, mock_client):
        """The inline template is instantiated under the ``region_path`` parent."""
        template = {"test": "test"}
        mock_client.region_path.return_value = PARENT
        self.hook.instantiate_inline_workflow_template(
            location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
        )
        mock_client.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
        mock_client.instantiate_inline_workflow_template.assert_called_once_with(
            parent=PARENT,
            template=template,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
    def test_wait_for_job(self, mock_get_job):
        """``wait_for_job`` raises ``AirflowException`` once the job reports ERROR."""
        # First lookup reports RUNNING, second reports ERROR.
        mock_get_job.side_effect = [
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
        ]
        with self.assertRaises(AirflowException):
            self.hook.wait_for_job(
                job_id=JOB_ID,
                location=GCP_LOCATION,
                project_id=GCP_PROJECT,
                wait_time=0,
            )
        calls = [
            mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
            mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
        ]
        # Must be ``assert_has_calls``: ``has_calls`` is just an auto-created
        # child mock and would verify nothing.
        mock_get_job.assert_has_calls(calls)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_get_job(self, mock_client):
        """``get_job`` passes ``location`` on to the job client as ``region``."""
        self.hook.get_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_submit_job(self, mock_client):
        """``submit_job`` passes ``location`` on to the job client as ``region``."""
        self.hook.submit_job(location=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.submit_job.assert_called_once_with(
            region=GCP_LOCATION,
            job=JOB,
            project_id=GCP_PROJECT,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.wait_for_job"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
    def test_submit(self, mock_submit_job, mock_wait_for_job):
        """``submit`` warns as deprecated, then submits the job and waits for it."""
        mock_submit_job.return_value.reference.job_id = JOB_ID
        with self.assertWarns(DeprecationWarning):
            self.hook.submit(project_id=GCP_PROJECT, job=JOB, region=GCP_LOCATION)
        mock_submit_job.assert_called_once_with(
            location=GCP_LOCATION, project_id=GCP_PROJECT, job=JOB
        )
        mock_wait_for_job.assert_called_once_with(
            location=GCP_LOCATION, project_id=GCP_PROJECT, job_id=JOB_ID
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job(self, mock_client):
        """``cancel_job`` passes ``location`` on to the job client as ``region``."""
        self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.cancel_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )
class TestDataprocHook(unittest.TestCase):
    """Unit tests for ``DataprocHook`` with every Dataproc client mocked out.

    Each test calls one hook method and checks the ``request`` dict and the
    ``retry`` / ``timeout`` / ``metadata`` arguments that the hook forwards
    to the (mocked) client factory or client method.
    """

    def setUp(self):
        """Create the hook with ``GoogleBaseHook.__init__`` replaced by ``mock_init``."""
        with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_init):
            self.hook = DataprocHook(gcp_conn_id="test")

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
    def test_get_cluster_client(self, mock_client, mock_client_info, mock_get_credentials):
        """The cluster client is built with a location-specific API endpoint."""
        self.hook.get_cluster_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_global(self, mock_client, mock_client_info, mock_get_credentials):
        """Without a location the template client gets ``client_options=None``."""
        _ = self.hook.get_template_client()
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_region(self, mock_client, mock_client_info, mock_get_credentials):
        """With a location the template client gets that region's API endpoint."""
        _ = self.hook.get_template_client(location='region1')
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                'api_endpoint': 'region1-dataproc.googleapis.com:443'
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"), new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("JobControllerClient"))
    def test_get_job_client(self, mock_client, mock_client_info, mock_get_credentials):
        """The job client is built with a location-specific API endpoint."""
        self.hook.get_job_client(location=GCP_LOCATION)
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_create_cluster(self, mock_client):
        """``create_cluster`` sends ``CLUSTER`` for the given name, config and labels."""
        self.hook.create_cluster(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_config=CLUSTER_CONFIG,
            labels=LABELS,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.create_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster=CLUSTER,
                request_id=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_delete_cluster(self, mock_client):
        """``delete_cluster`` forwards its arguments in the request dict."""
        self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.delete_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster_name=CLUSTER_NAME,
                cluster_uuid=None,
                request_id=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_diagnose_cluster(self, mock_client):
        """``diagnose_cluster`` sends the request and calls ``result()`` on the return value."""
        self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.diagnose_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster_name=CLUSTER_NAME,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )
        mock_client.return_value.diagnose_cluster.return_value.result.assert_called_once_with()

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_get_cluster(self, mock_client):
        """``get_cluster`` forwards its arguments in the request dict."""
        self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster_name=CLUSTER_NAME,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_list_clusters(self, mock_client):
        """``list_clusters`` sends the hook's ``filter_`` argument as ``filter``."""
        filter_ = "filter"

        self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.list_clusters.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                filter=filter_,
                page_size=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_update_cluster(self, mock_client):
        """``update_cluster`` sends ``location`` as ``region`` in the request dict."""
        update_mask = "update-mask"
        self.hook.update_cluster(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.update_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster=CLUSTER,
                cluster_name=CLUSTER_NAME,
                update_mask=update_mask,
                graceful_decommission_timeout=None,
                request_id=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_create_workflow_template(self, mock_client):
        """The template is created under ``projects/<project>/regions/<location>``."""
        template = {"test": "test"}
        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
        self.hook.create_workflow_template(
            location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
        )
        mock_client.return_value.create_workflow_template.assert_called_once_with(
            request=dict(parent=parent, template=template), retry=None, timeout=None, metadata=()
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_workflow_template(self, mock_client):
        """The template is instantiated by its full ``workflowTemplates`` resource name."""
        template_name = "template_name"
        name = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}/workflowTemplates/{template_name}'
        self.hook.instantiate_workflow_template(
            location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
        )
        mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
            request=dict(name=name, version=None, parameters=None, request_id=None),
            retry=None,
            timeout=None,
            metadata=(),
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_inline_workflow_template(self, mock_client):
        """The inline template is instantiated under the project/region parent."""
        template = {"test": "test"}
        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
        self.hook.instantiate_inline_workflow_template(
            location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
        )
        mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
            request=dict(parent=parent, template=template, request_id=None),
            retry=None,
            timeout=None,
            metadata=(),
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
    def test_wait_for_job(self, mock_get_job):
        """``wait_for_job`` raises ``AirflowException`` once the job reports ERROR."""
        # First lookup reports RUNNING, second reports ERROR.
        mock_get_job.side_effect = [
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.RUNNING)),
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.ERROR)),
        ]
        with pytest.raises(AirflowException):
            self.hook.wait_for_job(
                job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0
            )
        calls = [
            mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
            mock.call(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT),
        ]
        # Must be ``assert_has_calls``: ``has_calls`` is just an auto-created
        # child mock and would verify nothing.
        mock_get_job.assert_has_calls(calls)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_get_job(self, mock_client):
        """``get_job`` sends ``location`` as ``region`` in the request dict."""
        self.hook.get_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_job.assert_called_once_with(
            request=dict(
                region=GCP_LOCATION,
                job_id=JOB_ID,
                project_id=GCP_PROJECT,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_submit_job(self, mock_client):
        """``submit_job`` sends ``location`` as ``region`` in the request dict."""
        self.hook.submit_job(location=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.submit_job.assert_called_once_with(
            request=dict(
                region=GCP_LOCATION,
                job=JOB,
                project_id=GCP_PROJECT,
                request_id=None,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.wait_for_job"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
    def test_submit(self, mock_submit_job, mock_wait_for_job):
        """``submit`` warns as deprecated, then submits the job and waits for it."""
        mock_submit_job.return_value.reference.job_id = JOB_ID
        with pytest.warns(DeprecationWarning):
            self.hook.submit(project_id=GCP_PROJECT, job=JOB, region=GCP_LOCATION)
        mock_submit_job.assert_called_once_with(
            location=GCP_LOCATION, project_id=GCP_PROJECT, job=JOB
        )
        mock_wait_for_job.assert_called_once_with(
            location=GCP_LOCATION, project_id=GCP_PROJECT, job_id=JOB_ID
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job(self, mock_client):
        """``cancel_job`` sends ``location`` as ``region`` in the request dict."""
        self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.cancel_job.assert_called_once_with(
            request=dict(
                region=GCP_LOCATION,
                job_id=JOB_ID,
                project_id=GCP_PROJECT,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job_deprecation_warning(self, mock_client):
        """Without a location, ``cancel_job`` warns and uses the ``'global'`` region."""
        with pytest.warns(DeprecationWarning):
            self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location='global')
        mock_client.return_value.cancel_job.assert_called_once_with(
            request=dict(
                region='global',
                job_id=JOB_ID,
                project_id=GCP_PROJECT,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )