コード例 #1
0
ファイル: automl.py プロジェクト: zorseti/airflow
 def execute(self, context):
     """Submit an AutoML batch prediction and return its result as a dict.

     A ``CloudAutoMLHook`` is created for ``self.gcp_conn_id`` and asked to run
     ``batch_predict`` for ``self.model_id`` using the operator's input/output
     configs, project, location, params and retry/timeout/metadata settings.

     :param context: Task context; not read by this method.
     :return: ``operation.result()`` converted with ``MessageToDict``.
     """
     automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)
     self.log.info("Fetch batch prediction.")
     # NOTE(review): ``self.params`` is forwarded as the prediction params. If this
     # class derives from Airflow's BaseOperator, that name may collide with the
     # operator-level ``params`` attribute -- confirm against __init__.
     request = dict(
         model_id=self.model_id,
         input_config=self.input_config,
         output_config=self.output_config,
         project_id=self.project_id,
         location=self.location,
         params=self.params,
         retry=self.retry,
         timeout=self.timeout,
         metadata=self.metadata,
     )
     prediction_op = automl_hook.batch_predict(**request)
     # Presumably waits for the long-running operation to finish before returning.
     finished = prediction_op.result()
     prediction_dict = MessageToDict(finished)
     self.log.info("Batch prediction ready.")
     return prediction_dict
コード例 #2
0
ファイル: automl.py プロジェクト: thesuperzapper/airflow
 def execute(self, context: 'Context'):
     """Submit an AutoML batch prediction and return its result as a dict.

     The ``CloudAutoMLHook`` is created for ``self.gcp_conn_id`` with the
     operator's ``impersonation_chain``. It runs ``batch_predict`` for
     ``self.model_id``, passing ``self.prediction_params`` as the prediction
     params along with the configured input/output, project, location and
     retry/timeout/metadata settings.

     :param context: Task context; not read by this method.
     :return: ``operation.result()`` converted with ``BatchPredictResult.to_dict``.
     """
     automl_hook = CloudAutoMLHook(
         gcp_conn_id=self.gcp_conn_id,
         impersonation_chain=self.impersonation_chain,
     )
     self.log.info("Fetch batch prediction.")
     request = dict(
         model_id=self.model_id,
         input_config=self.input_config,
         output_config=self.output_config,
         project_id=self.project_id,
         location=self.location,
         params=self.prediction_params,
         retry=self.retry,
         timeout=self.timeout,
         metadata=self.metadata,
     )
     prediction_op = automl_hook.batch_predict(**request)
     # Presumably waits for the long-running operation to finish before returning.
     finished = prediction_op.result()
     prediction_dict = BatchPredictResult.to_dict(finished)
     self.log.info("Batch prediction ready.")
     return prediction_dict
コード例 #3
0
ファイル: test_automl.py プロジェクト: lgov/airflow
class TestAuoMLHook(unittest.TestCase):
    """Unit tests for ``CloudAutoMLHook``.

    ``setUp`` constructs the hook while ``GoogleBaseHook.__init__`` is patched,
    then stubs ``_get_credentials`` to return ``CREDENTIALS``. Each test patches
    one client class or client method and asserts the exact call the hook makes.

    NOTE: the class name is misspelled ("AuoML"); it is left unchanged so that
    existing test IDs keep working.
    """

    @staticmethod
    def _unset_call_options():
        """Return the retry/timeout/metadata kwargs expected when a test omits them."""
        return dict(retry=None, timeout=None, metadata=None)

    def setUp(self) -> None:
        base_init_target = "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.__init__"
        with mock.patch(base_init_target, new=mock_base_gcp_hook_no_default_project_id):
            self.hook = CloudAutoMLHook()
        # Stub credential resolution so clients are built with CREDENTIALS.
        self.hook._get_credentials = mock.MagicMock(return_value=CREDENTIALS)  # type: ignore

    @mock.patch(
        "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.client_info",
        new_callable=lambda: CLIENT_INFO,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient")
    def test_get_conn(self, automl_client_cls, _client_info):
        """get_conn() builds an AutoMlClient from the stubbed credentials and client info."""
        self.hook.get_conn()
        automl_client_cls.assert_called_once_with(
            client_info=CLIENT_INFO,
            credentials=CREDENTIALS,
        )

    @mock.patch(
        "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.client_info",
        new_callable=lambda: CLIENT_INFO,
    )
    @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient")
    def test_prediction_client(self, prediction_client_cls, _client_info):
        """Reading ``prediction_client`` builds a PredictionServiceClient."""
        _ = self.hook.prediction_client
        prediction_client_cls.assert_called_once_with(
            client_info=CLIENT_INFO,
            credentials=CREDENTIALS,
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model")
    def test_create_model(self, create_model):
        """create_model() sends the model with parent=LOCATION_PATH."""
        self.hook.create_model(project_id=GCP_PROJECT_ID, location=GCP_LOCATION, model=MODEL)

        create_model.assert_called_once_with(
            parent=LOCATION_PATH, model=MODEL, **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
    def test_batch_predict(self, batch_predict):
        """batch_predict() targets MODEL_PATH and passes params=None when none are given."""
        self.hook.batch_predict(
            project_id=GCP_PROJECT_ID,
            location=GCP_LOCATION,
            model_id=MODEL_ID,
            input_config=INPUT_CONFIG,
            output_config=OUTPUT_CONFIG,
        )

        batch_predict.assert_called_once_with(
            name=MODEL_PATH,
            input_config=INPUT_CONFIG,
            output_config=OUTPUT_CONFIG,
            params=None,
            **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
    def test_predict(self, predict):
        """predict() targets MODEL_PATH and passes params=None when none are given."""
        self.hook.predict(
            project_id=GCP_PROJECT_ID,
            location=GCP_LOCATION,
            model_id=MODEL_ID,
            payload=PAYLOAD,
        )

        predict.assert_called_once_with(
            name=MODEL_PATH,
            payload=PAYLOAD,
            params=None,
            **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
    def test_create_dataset(self, create_dataset):
        """create_dataset() sends the dataset with parent=LOCATION_PATH."""
        self.hook.create_dataset(project_id=GCP_PROJECT_ID, location=GCP_LOCATION, dataset=DATASET)

        create_dataset.assert_called_once_with(
            parent=LOCATION_PATH, dataset=DATASET, **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
    def test_import_dataset(self, import_data):
        """import_data() targets DATASET_PATH with the given input config."""
        self.hook.import_data(
            project_id=GCP_PROJECT_ID,
            location=GCP_LOCATION,
            dataset_id=DATASET_ID,
            input_config=INPUT_CONFIG,
        )

        import_data.assert_called_once_with(
            name=DATASET_PATH, input_config=INPUT_CONFIG, **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
    def test_list_column_specs(self, list_column_specs):
        """list_column_specs() targets the path from AutoMlClient.table_spec_path."""
        table_spec_id = "table_spec_id"
        spec_filter = "filter"
        page_size = 42
        expected_parent = AutoMlClient.table_spec_path(
            GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec_id
        )

        self.hook.list_column_specs(
            project_id=GCP_PROJECT_ID,
            location=GCP_LOCATION,
            dataset_id=DATASET_ID,
            table_spec_id=table_spec_id,
            field_mask=MASK,
            filter_=spec_filter,
            page_size=page_size,
        )

        list_column_specs.assert_called_once_with(
            parent=expected_parent,
            field_mask=MASK,
            filter_=spec_filter,
            page_size=page_size,
            **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
    def test_get_model(self, get_model):
        """get_model() targets MODEL_PATH."""
        self.hook.get_model(project_id=GCP_PROJECT_ID, location=GCP_LOCATION, model_id=MODEL_ID)

        get_model.assert_called_once_with(name=MODEL_PATH, **self._unset_call_options())

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
    def test_delete_model(self, delete_model):
        """delete_model() targets MODEL_PATH."""
        self.hook.delete_model(project_id=GCP_PROJECT_ID, location=GCP_LOCATION, model_id=MODEL_ID)

        delete_model.assert_called_once_with(name=MODEL_PATH, **self._unset_call_options())

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
    def test_update_dataset(self, update_dataset):
        """update_dataset() passes the dataset and update mask straight through."""
        self.hook.update_dataset(update_mask=MASK, dataset=DATASET)

        update_dataset.assert_called_once_with(
            dataset=DATASET, update_mask=MASK, **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model")
    def test_deploy_model(self, deploy_model):
        """deploy_model() forwards image_detection_metadata under its client-side name."""
        detection_metadata = {}

        self.hook.deploy_model(
            project_id=GCP_PROJECT_ID,
            location=GCP_LOCATION,
            model_id=MODEL_ID,
            image_detection_metadata=detection_metadata,
        )

        deploy_model.assert_called_once_with(
            name=MODEL_PATH,
            image_object_detection_model_deployment_metadata=detection_metadata,
            **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs")
    def test_list_table_specs(self, list_table_specs):
        """list_table_specs() targets DATASET_PATH with the given filter and page size."""
        spec_filter = "filter"
        page_size = 42

        self.hook.list_table_specs(
            project_id=GCP_PROJECT_ID,
            location=GCP_LOCATION,
            dataset_id=DATASET_ID,
            filter_=spec_filter,
            page_size=page_size,
        )

        list_table_specs.assert_called_once_with(
            parent=DATASET_PATH,
            filter_=spec_filter,
            page_size=page_size,
            **self._unset_call_options()
        )

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
    def test_list_datasets(self, list_datasets):
        """list_datasets() lists under parent=LOCATION_PATH."""
        self.hook.list_datasets(project_id=GCP_PROJECT_ID, location=GCP_LOCATION)

        list_datasets.assert_called_once_with(parent=LOCATION_PATH, **self._unset_call_options())

    @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
    def test_delete_dataset(self, delete_dataset):
        """delete_dataset() targets DATASET_PATH."""
        self.hook.delete_dataset(project_id=GCP_PROJECT_ID, location=GCP_LOCATION, dataset_id=DATASET_ID)

        delete_dataset.assert_called_once_with(name=DATASET_PATH, **self._unset_call_options())