def execute(self, context):
    """Request the column specs of a table spec and return them as dicts.

    A ``CloudAutoMLHook`` is created from ``self.gcp_conn_id`` and the
    arguments stored on this instance are forwarded to its
    ``list_column_specs`` method. Every item produced by the returned
    iterable is converted with ``MessageToDict``.

    :param context: execution context supplied by the caller; not read here.
    :return: list holding one ``MessageToDict`` result per column spec.
    """
    automl_hook = CloudAutoMLHook(gcp_conn_id=self.gcp_conn_id)

    self.log.info("Requesting column specs.")
    request_kwargs = {
        "dataset_id": self.dataset_id,
        "table_spec_id": self.table_spec_id,
        "field_mask": self.field_mask,
        "filter_": self.filter_,
        "page_size": self.page_size,
        "location": self.location,
        "project_id": self.project_id,
        "retry": self.retry,
        "timeout": self.timeout,
        "metadata": self.metadata,
    }
    column_specs = automl_hook.list_column_specs(**request_kwargs)

    # The whole iterable is consumed and converted before completion is logged.
    column_spec_dicts = list(map(MessageToDict, column_specs))
    self.log.info("Columns specs obtained.")
    return column_spec_dicts
class TestAuoMLHook(unittest.TestCase):
    """Tests for the keyword arguments ``CloudAutoMLHook`` passes to its clients.

    Each test patches one client class or client method, calls the matching
    hook method, and asserts that the patched object was called exactly once
    with the expected keyword arguments.
    """

    def setUp(self) -> None:
        """Create the hook with a patched base ``__init__`` and mocked credentials."""
        with mock.patch(
            "airflow.gcp.hooks.automl.GoogleCloudBaseHook.__init__",
            new=mock_base_gcp_hook_no_default_project_id,
        ):
            self.hook = CloudAutoMLHook()
            credentials_getter = mock.MagicMock(return_value=CREDENTIALS)
            self.hook._get_credentials = credentials_getter  # type: ignore

    @mock.patch(
        "airflow.gcp.hooks.automl.GoogleCloudBaseHook.client_info",
        new_callable=lambda: CLIENT_INFO,
    )
    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient")
    def test_get_conn(self, mock_automl_client, mock_client_info):
        """``get_conn`` constructs ``AutoMlClient`` with credentials and client info."""
        self.hook.get_conn()

        expected_kwargs = {"credentials": CREDENTIALS, "client_info": CLIENT_INFO}
        mock_automl_client.assert_called_once_with(**expected_kwargs)

    @mock.patch(
        "airflow.gcp.hooks.automl.GoogleCloudBaseHook.client_info",
        new_callable=lambda: CLIENT_INFO,
    )
    @mock.patch("airflow.gcp.hooks.automl.PredictionServiceClient")
    def test_prediction_client(self, mock_prediction_client, mock_client_info):
        """Reading ``prediction_client`` constructs ``PredictionServiceClient`` once."""
        # Only the attribute access matters here; the returned value is discarded.
        _ = self.hook.prediction_client

        expected_kwargs = {"credentials": CREDENTIALS, "client_info": CLIENT_INFO}
        mock_prediction_client.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.create_model")
    def test_create_model(self, mock_create_model):
        """``create_model`` calls the client with ``LOCATION_PATH`` and the model."""
        self.hook.create_model(
            model=MODEL,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "parent": LOCATION_PATH,
            "model": MODEL,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_create_model.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.PredictionServiceClient.batch_predict")
    def test_batch_predict(self, mock_batch_predict):
        """``batch_predict`` calls the client with ``MODEL_PATH`` and both configs."""
        self.hook.batch_predict(
            model_id=MODEL_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            input_config=INPUT_CONFIG,
            output_config=OUTPUT_CONFIG,
        )

        expected_kwargs = {
            "name": MODEL_PATH,
            "input_config": INPUT_CONFIG,
            "output_config": OUTPUT_CONFIG,
            "params": None,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_batch_predict.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.PredictionServiceClient.predict")
    def test_predict(self, mock_predict):
        """``predict`` calls the client with ``MODEL_PATH`` and the payload."""
        self.hook.predict(
            model_id=MODEL_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            payload=PAYLOAD,
        )

        expected_kwargs = {
            "name": MODEL_PATH,
            "payload": PAYLOAD,
            "params": None,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_predict.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.create_dataset")
    def test_create_dataset(self, mock_create_dataset):
        """``create_dataset`` calls the client with ``LOCATION_PATH`` and the dataset."""
        self.hook.create_dataset(
            dataset=DATASET,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "parent": LOCATION_PATH,
            "dataset": DATASET,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_create_dataset.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.import_data")
    def test_import_dataset(self, mock_import_data):
        """``import_data`` calls the client with ``DATASET_PATH`` and the input config."""
        self.hook.import_data(
            dataset_id=DATASET_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            input_config=INPUT_CONFIG,
        )

        expected_kwargs = {
            "name": DATASET_PATH,
            "input_config": INPUT_CONFIG,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_import_data.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.list_column_specs")
    def test_list_column_specs(self, mock_list_column_specs):
        """``list_column_specs`` calls the client with the table spec path as parent."""
        table_spec_id = "table_spec_id"
        spec_filter = "filter"
        specs_per_page = 42

        self.hook.list_column_specs(
            dataset_id=DATASET_ID,
            table_spec_id=table_spec_id,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            field_mask=MASK,
            filter_=spec_filter,
            page_size=specs_per_page,
        )

        # The expected parent is built with the same client path helper.
        expected_parent = AutoMlClient.table_spec_path(
            GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec_id
        )
        expected_kwargs = {
            "parent": expected_parent,
            "field_mask": MASK,
            "filter_": spec_filter,
            "page_size": specs_per_page,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_list_column_specs.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.get_model")
    def test_get_model(self, mock_get_model):
        """``get_model`` calls the client with ``name=MODEL_PATH``."""
        self.hook.get_model(
            model_id=MODEL_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "name": MODEL_PATH,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_get_model.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.delete_model")
    def test_delete_model(self, mock_delete_model):
        """``delete_model`` calls the client with ``name=MODEL_PATH``."""
        self.hook.delete_model(
            model_id=MODEL_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "name": MODEL_PATH,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_delete_model.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.update_dataset")
    def test_update_dataset(self, mock_update_dataset):
        """``update_dataset`` calls the client with the dataset and update mask."""
        self.hook.update_dataset(
            dataset=DATASET,
            update_mask=MASK,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "dataset": DATASET,
            "update_mask": MASK,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_update_dataset.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.deploy_model")
    def test_deploy_model(self, mock_deploy_model):
        """``deploy_model`` passes the detection metadata under the client's keyword."""
        detection_metadata = {}

        self.hook.deploy_model(
            model_id=MODEL_ID,
            image_detection_metadata=detection_metadata,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "name": MODEL_PATH,
            "retry": None,
            "timeout": None,
            "metadata": None,
            "image_object_detection_model_deployment_metadata": detection_metadata,
        }
        mock_deploy_model.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.list_table_specs")
    def test_list_table_specs(self, mock_list_table_specs):
        """``list_table_specs`` calls the client with ``DATASET_PATH`` as parent."""
        spec_filter = "filter"
        specs_per_page = 42

        self.hook.list_table_specs(
            dataset_id=DATASET_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
            filter_=spec_filter,
            page_size=specs_per_page,
        )

        expected_kwargs = {
            "parent": DATASET_PATH,
            "filter_": spec_filter,
            "page_size": specs_per_page,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_list_table_specs.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.list_datasets")
    def test_list_datasets(self, mock_list_datasets):
        """``list_datasets`` calls the client with ``LOCATION_PATH`` as parent."""
        self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)

        expected_kwargs = {
            "parent": LOCATION_PATH,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_list_datasets.assert_called_once_with(**expected_kwargs)

    @mock.patch("airflow.gcp.hooks.automl.AutoMlClient.delete_dataset")
    def test_delete_dataset(self, mock_delete_dataset):
        """``delete_dataset`` calls the client with ``name=DATASET_PATH``."""
        self.hook.delete_dataset(
            dataset_id=DATASET_ID,
            location=GCP_LOCATION,
            project_id=GCP_PROJECT_ID,
        )

        expected_kwargs = {
            "name": DATASET_PATH,
            "retry": None,
            "timeout": None,
            "metadata": None,
        }
        mock_delete_dataset.assert_called_once_with(**expected_kwargs)