示例#1
0
class TestDataprocHook(unittest.TestCase):
    """Unit tests for ``DataprocHook``.

    Each test patches either the Google Cloud client classes or the hook's
    own client getters with ``mock`` objects. It then checks the exact
    arguments the hook forwards to them, so the tests verify argument
    forwarding rather than talking to Dataproc.

    Stacked ``@mock.patch`` decorators are applied bottom-up. The mock created
    by the decorator closest to ``def`` is therefore passed as the first mock
    argument.
    """

    def setUp(self):
        # Replace the base-class initializer with ``mock_init`` (defined
        # elsewhere in this module) so that constructing the hook does not
        # run the real ``CloudBaseHook.__init__``.
        with mock.patch(BASE_STRING.format("CloudBaseHook.__init__"),
                        new=mock_init):
            self.hook = DataprocHook(gcp_conn_id="test")

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(
        DATAPROC_STRING.format("DataprocHook.client_info"),
        new_callable=mock.PropertyMock,  # client_info is a property on the hook
    )
    @mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
    def test_get_cluster_client(self, mock_client, mock_client_info,
                                mock_get_credentials):
        self.hook.get_cluster_client(location=GCP_LOCATION)
        # The cluster client must be pointed at the regional endpoint.
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint":
                "{}-dataproc.googleapis.com:443".format(GCP_LOCATION)
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(
        DATAPROC_STRING.format("DataprocHook.client_info"),
        new_callable=mock.PropertyMock,
    )
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client(self, mock_client, mock_client_info,
                                 mock_get_credentials):
        # In this version get_template_client is read as an attribute (not
        # called), and no regional client_options are passed.
        _ = self.hook.get_template_client
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(
        DATAPROC_STRING.format("DataprocHook.client_info"),
        new_callable=mock.PropertyMock,
    )
    @mock.patch(DATAPROC_STRING.format("JobControllerClient"))
    def test_get_job_client(self, mock_client, mock_client_info,
                            mock_get_credentials):
        self.hook.get_job_client(location=GCP_LOCATION)
        # The job client must be pointed at the regional endpoint.
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint":
                "{}-dataproc.googleapis.com:443".format(GCP_LOCATION)
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_create_cluster(self, mock_client):
        self.hook.create_cluster(project_id=GCP_PROJECT,
                                 region=GCP_LOCATION,
                                 cluster=CLUSTER)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.create_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_delete_cluster(self, mock_client):
        self.hook.delete_cluster(project_id=GCP_PROJECT,
                                 region=GCP_LOCATION,
                                 cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.delete_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_uuid=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_diagnose_cluster(self, mock_client):
        self.hook.diagnose_cluster(project_id=GCP_PROJECT,
                                   region=GCP_LOCATION,
                                   cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.diagnose_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_get_cluster(self, mock_client):
        self.hook.get_cluster(project_id=GCP_PROJECT,
                              region=GCP_LOCATION,
                              cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_list_clusters(self, mock_client):
        filter_ = "filter"

        self.hook.list_clusters(project_id=GCP_PROJECT,
                                region=GCP_LOCATION,
                                filter_=filter_)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.list_clusters.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            filter_=filter_,
            page_size=None,
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_update_cluster(self, mock_client):
        update_mask = "update-mask"
        # update_cluster takes ``location`` but is expected to forward it to
        # the client as ``region``.
        self.hook.update_cluster(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.update_cluster.assert_called_once_with(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
            graceful_decommission_timeout=None,
            metadata=None,
            request_id=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_create_workflow_template(self, mock_client):
        # get_template_client is patched directly, so ``mock_client`` itself
        # stands in for the template client (no ``return_value``).
        template = {"test": "test"}
        mock_client.region_path.return_value = PARENT
        self.hook.create_workflow_template(location=GCP_LOCATION,
                                           template=template,
                                           project_id=GCP_PROJECT)
        mock_client.region_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION)
        mock_client.create_workflow_template.assert_called_once_with(
            parent=PARENT,
            template=template,
            retry=None,
            timeout=None,
            metadata=None)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_workflow_template(self, mock_client):
        template_name = "template_name"
        mock_client.workflow_template_path.return_value = NAME
        self.hook.instantiate_workflow_template(location=GCP_LOCATION,
                                                template_name=template_name,
                                                project_id=GCP_PROJECT)
        mock_client.workflow_template_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION, template_name)
        mock_client.instantiate_workflow_template.assert_called_once_with(
            name=NAME,
            version=None,
            parameters=None,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_inline_workflow_template(self, mock_client):
        template = {"test": "test"}
        mock_client.region_path.return_value = PARENT
        self.hook.instantiate_inline_workflow_template(location=GCP_LOCATION,
                                                       template=template,
                                                       project_id=GCP_PROJECT)
        mock_client.region_path.assert_called_once_with(
            GCP_PROJECT, GCP_LOCATION)
        mock_client.instantiate_inline_workflow_template.assert_called_once_with(
            parent=PARENT,
            template=template,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
    def test_wait_for_job(self, mock_get_job):
        # The first poll reports RUNNING and the second reports ERROR. The
        # hook is expected to keep polling until ERROR and then raise.
        mock_get_job.side_effect = [
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
        ]
        with self.assertRaises(AirflowException):
            self.hook.wait_for_job(
                job_id=JOB_ID,
                location=GCP_LOCATION,
                project_id=GCP_PROJECT,
                wait_time=0,  # presumably the poll delay; keeps the test fast
            )
        expected_call = mock.call(location=GCP_LOCATION,
                                  job_id=JOB_ID,
                                  project_id=GCP_PROJECT)
        calls = [expected_call, expected_call]
        # Must be ``assert_has_calls``. A name like ``has_calls`` is silently
        # auto-created by MagicMock and asserts nothing.
        mock_get_job.assert_has_calls(calls)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_get_job(self, mock_client):
        self.hook.get_job(location=GCP_LOCATION,
                          job_id=JOB_ID,
                          project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_submit_job(self, mock_client):
        self.hook.submit_job(location=GCP_LOCATION,
                             job=JOB,
                             project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.submit_job.assert_called_once_with(
            region=GCP_LOCATION,
            job=JOB,
            project_id=GCP_PROJECT,
            request_id=None,
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.wait_for_job"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
    def test_submit(self, mock_submit_job, mock_wait_for_job):
        # The deprecated ``submit`` is expected to warn, delegate to
        # ``submit_job`` and then wait on the returned job's id.
        mock_submit_job.return_value.reference.job_id = JOB_ID
        with self.assertWarns(DeprecationWarning):
            self.hook.submit(project_id=GCP_PROJECT,
                             job=JOB,
                             region=GCP_LOCATION)
        mock_submit_job.assert_called_once_with(location=GCP_LOCATION,
                                                project_id=GCP_PROJECT,
                                                job=JOB)
        mock_wait_for_job.assert_called_once_with(location=GCP_LOCATION,
                                                  project_id=GCP_PROJECT,
                                                  job_id=JOB_ID)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job(self, mock_client):
        self.hook.cancel_job(location=GCP_LOCATION,
                             job_id=JOB_ID,
                             project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.cancel_job.assert_called_once_with(
            region=GCP_LOCATION,
            job_id=JOB_ID,
            project_id=GCP_PROJECT,
            retry=None,
            timeout=None,
            metadata=None,
        )
示例#2
0
class TestDataprocHook(unittest.TestCase):
    """Unit tests for ``DataprocHook``.

    Each test patches either the Google Cloud client classes or the hook's
    own client getters with ``mock`` objects. It then checks the exact
    ``request``/``retry``/``timeout``/``metadata`` arguments the hook forwards
    to them, so the tests verify argument forwarding rather than talking to
    Dataproc.

    Stacked ``@mock.patch`` decorators are applied bottom-up. The mock created
    by the decorator closest to ``def`` is therefore passed as the first mock
    argument.
    """

    def setUp(self):
        # Replace the base-class initializer with ``mock_init`` (defined
        # elsewhere in this module) so that constructing the hook does not
        # run the real ``GoogleBaseHook.__init__``.
        with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"),
                        new=mock_init):
            self.hook = DataprocHook(gcp_conn_id="test")

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)  # client_info is a property
    @mock.patch(DATAPROC_STRING.format("ClusterControllerClient"))
    def test_get_cluster_client(self, mock_client, mock_client_info,
                                mock_get_credentials):
        self.hook.get_cluster_client(location=GCP_LOCATION)
        # The cluster client must be pointed at the regional endpoint.
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_global(self, mock_client, mock_client_info,
                                        mock_get_credentials):
        # Without a location no endpoint override is passed.
        _ = self.hook.get_template_client()
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceClient"))
    def test_get_template_client_region(self, mock_client, mock_client_info,
                                        mock_get_credentials):
        # With a location the client targets that region's endpoint.
        _ = self.hook.get_template_client(location='region1')
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                'api_endpoint': 'region1-dataproc.googleapis.com:443'
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook._get_credentials"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.client_info"),
                new_callable=mock.PropertyMock)
    @mock.patch(DATAPROC_STRING.format("JobControllerClient"))
    def test_get_job_client(self, mock_client, mock_client_info,
                            mock_get_credentials):
        self.hook.get_job_client(location=GCP_LOCATION)
        # The job client must be pointed at the regional endpoint.
        mock_client.assert_called_once_with(
            credentials=mock_get_credentials.return_value,
            client_info=mock_client_info.return_value,
            client_options={
                "api_endpoint": f"{GCP_LOCATION}-dataproc.googleapis.com:443"
            },
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_create_cluster(self, mock_client):
        # The hook is expected to assemble ``CLUSTER`` from the name, config
        # and labels and send it inside a single ``request`` dict.
        self.hook.create_cluster(
            project_id=GCP_PROJECT,
            region=GCP_LOCATION,
            cluster_name=CLUSTER_NAME,
            cluster_config=CLUSTER_CONFIG,
            labels=LABELS,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.create_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster=CLUSTER,
                request_id=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_delete_cluster(self, mock_client):
        self.hook.delete_cluster(project_id=GCP_PROJECT,
                                 region=GCP_LOCATION,
                                 cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.delete_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster_name=CLUSTER_NAME,
                cluster_uuid=None,
                request_id=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_diagnose_cluster(self, mock_client):
        self.hook.diagnose_cluster(project_id=GCP_PROJECT,
                                   region=GCP_LOCATION,
                                   cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.diagnose_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster_name=CLUSTER_NAME,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )
        # The hook is expected to block on the returned operation's result().
        diagnose_operation = mock_client.return_value.diagnose_cluster.return_value
        diagnose_operation.result.assert_called_once_with()

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_get_cluster(self, mock_client):
        self.hook.get_cluster(project_id=GCP_PROJECT,
                              region=GCP_LOCATION,
                              cluster_name=CLUSTER_NAME)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster_name=CLUSTER_NAME,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_list_clusters(self, mock_client):
        filter_ = "filter"

        self.hook.list_clusters(project_id=GCP_PROJECT,
                                region=GCP_LOCATION,
                                filter_=filter_)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        # The hook's ``filter_`` argument is sent as ``filter`` in the request.
        mock_client.return_value.list_clusters.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                filter=filter_,
                page_size=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
    def test_update_cluster(self, mock_client):
        update_mask = "update-mask"
        # update_cluster takes ``location`` but is expected to forward it as
        # ``region`` inside the request.
        self.hook.update_cluster(
            project_id=GCP_PROJECT,
            location=GCP_LOCATION,
            cluster=CLUSTER,
            cluster_name=CLUSTER_NAME,
            update_mask=update_mask,
        )
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.update_cluster.assert_called_once_with(
            request=dict(
                project_id=GCP_PROJECT,
                region=GCP_LOCATION,
                cluster=CLUSTER,
                cluster_name=CLUSTER_NAME,
                update_mask=update_mask,
                graceful_decommission_timeout=None,
                request_id=None,
            ),
            metadata=None,
            retry=None,
            timeout=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_create_workflow_template(self, mock_client):
        template = {"test": "test"}
        # The hook is expected to build the parent resource path itself.
        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
        self.hook.create_workflow_template(location=GCP_LOCATION,
                                           template=template,
                                           project_id=GCP_PROJECT)
        # Workflow-template calls forward ``metadata=()`` rather than None.
        mock_client.return_value.create_workflow_template.assert_called_once_with(
            request=dict(parent=parent, template=template),
            retry=None,
            timeout=None,
            metadata=())

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_workflow_template(self, mock_client):
        template_name = "template_name"
        name = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}/workflowTemplates/{template_name}'
        self.hook.instantiate_workflow_template(location=GCP_LOCATION,
                                                template_name=template_name,
                                                project_id=GCP_PROJECT)
        mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
            request=dict(name=name,
                         version=None,
                         parameters=None,
                         request_id=None),
            retry=None,
            timeout=None,
            metadata=(),
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
    def test_instantiate_inline_workflow_template(self, mock_client):
        template = {"test": "test"}
        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
        self.hook.instantiate_inline_workflow_template(location=GCP_LOCATION,
                                                       template=template,
                                                       project_id=GCP_PROJECT)
        mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
            request=dict(parent=parent, template=template, request_id=None),
            retry=None,
            timeout=None,
            metadata=(),
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
    def test_wait_for_job(self, mock_get_job):
        # The first poll reports RUNNING and the second reports ERROR. The
        # hook is expected to keep polling until ERROR and then raise.
        mock_get_job.side_effect = [
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.RUNNING)),
            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.ERROR)),
        ]
        with pytest.raises(AirflowException):
            self.hook.wait_for_job(job_id=JOB_ID,
                                   location=GCP_LOCATION,
                                   project_id=GCP_PROJECT,
                                   wait_time=0)  # presumably the poll delay
        expected_call = mock.call(location=GCP_LOCATION,
                                  job_id=JOB_ID,
                                  project_id=GCP_PROJECT)
        calls = [expected_call, expected_call]
        # Must be ``assert_has_calls``. A name like ``has_calls`` is silently
        # auto-created by MagicMock and asserts nothing.
        mock_get_job.assert_has_calls(calls)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_get_job(self, mock_client):
        self.hook.get_job(location=GCP_LOCATION,
                          job_id=JOB_ID,
                          project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.get_job.assert_called_once_with(
            request=dict(
                region=GCP_LOCATION,
                job_id=JOB_ID,
                project_id=GCP_PROJECT,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_submit_job(self, mock_client):
        self.hook.submit_job(location=GCP_LOCATION,
                             job=JOB,
                             project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.submit_job.assert_called_once_with(
            request=dict(
                region=GCP_LOCATION,
                job=JOB,
                project_id=GCP_PROJECT,
                request_id=None,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.wait_for_job"))
    @mock.patch(DATAPROC_STRING.format("DataprocHook.submit_job"))
    def test_submit(self, mock_submit_job, mock_wait_for_job):
        # The deprecated ``submit`` is expected to warn, delegate to
        # ``submit_job`` and then wait on the returned job's id.
        mock_submit_job.return_value.reference.job_id = JOB_ID
        with pytest.warns(DeprecationWarning):
            self.hook.submit(project_id=GCP_PROJECT,
                             job=JOB,
                             region=GCP_LOCATION)
        mock_submit_job.assert_called_once_with(location=GCP_LOCATION,
                                                project_id=GCP_PROJECT,
                                                job=JOB)
        mock_wait_for_job.assert_called_once_with(location=GCP_LOCATION,
                                                  project_id=GCP_PROJECT,
                                                  job_id=JOB_ID)

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job(self, mock_client):
        self.hook.cancel_job(location=GCP_LOCATION,
                             job_id=JOB_ID,
                             project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location=GCP_LOCATION)
        mock_client.return_value.cancel_job.assert_called_once_with(
            request=dict(
                region=GCP_LOCATION,
                job_id=JOB_ID,
                project_id=GCP_PROJECT,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )

    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job_client"))
    def test_cancel_job_deprecation_warning(self, mock_client):
        # Omitting ``location`` is expected to emit a DeprecationWarning and
        # fall back to the 'global' region.
        with pytest.warns(DeprecationWarning):
            self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
        mock_client.assert_called_once_with(location='global')
        mock_client.return_value.cancel_job.assert_called_once_with(
            request=dict(
                region='global',
                job_id=JOB_ID,
                project_id=GCP_PROJECT,
            ),
            retry=None,
            timeout=None,
            metadata=None,
        )